manager: rename start_step to start_quorum and move step changes to should_commit (#46)
d4l3k authored Dec 18, 2024
1 parent 78c5721 commit a484e4f
Showing 5 changed files with 164 additions and 47 deletions.
70 changes: 68 additions & 2 deletions src/manager.rs
@@ -239,7 +239,8 @@ impl ManagerService for Arc<Manager> {
.await
.map_err(|e| Status::internal(e.to_string()))?;

-        let participants = &quorum.participants;
+        let mut participants = quorum.participants.clone();
+        participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));

let mut replica_rank = 10000000000;
for (i, p) in participants.iter().enumerate() {
@@ -266,7 +267,7 @@
// Decide whether we should be healing:
// 1. if we're not at the max step
// 2. if everyone is at the first step and we're not the primary
-        let heal = max_step != req.step || max_step == 1 && primary.replica_id != self.replica_id;
+        let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id;
if heal {
info!(
"healing is required step={}, max_step={}",
@@ -475,4 +476,69 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_get_quorum_heal_first_step() -> Result<()> {
let lighthouse = Lighthouse::new(LighthouseOpt {
bind: "[::]:0".to_string(),
join_timeout_ms: 100,
min_replicas: 2,
quorum_tick_ms: 100,
})
.await?;
let lighthouse_fut = tokio::spawn(lighthouse.clone().run());

let mut manager_futs: Vec<tokio::task::JoinHandle<Result<ManagerQuorumResponse>>> =
Vec::new();

for replica_id in 0..2 {
let lighthouse_addr = lighthouse.address();
manager_futs.push(tokio::spawn(async move {
let manager = Manager::new(
format!("rep_{}", replica_id),
lighthouse_addr,
"addr".to_string(),
"[::]:0".to_string(),
"store_addr".to_string(),
1, // world size
)
.await?;
let manager_fut = tokio::spawn(manager.clone().run());

let mut client =
manager_client_new(manager.address(), Duration::from_secs(10)).await?;

let request = tonic::Request::new(ManagerQuorumRequest {
rank: 0,
step: 0,
checkpoint_server_addr: "addr".to_string(),
});

let result = client.quorum(request).await?.into_inner();

manager_fut.abort();

Ok(result)
}));
}

let resp_a = manager_futs.swap_remove(0).await??;
let resp_b = manager_futs.swap_remove(0).await??;

lighthouse_fut.abort();

assert_eq!(resp_a.quorum_id, 1);
assert_eq!(resp_a.max_step, 0);
assert_eq!(resp_a.replica_rank, 0);
assert_eq!(resp_a.replica_world_size, 2);
assert_eq!(resp_a.heal, false);

assert_eq!(resp_b.quorum_id, 1);
assert_eq!(resp_b.max_step, 0);
assert_eq!(resp_b.replica_rank, 1);
assert_eq!(resp_b.replica_world_size, 2);
assert_eq!(resp_b.heal, true);

Ok(())
}
}
46 changes: 32 additions & 14 deletions torchft/manager.py
@@ -182,7 +182,6 @@ def _manager_state_dict() -> Dict[str, T]:
self._batches_committed = 0

# first step is 1
-        self._should_step = True
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

@@ -218,8 +217,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
fut.set_result(grad)
return fut

-        assert self._quorum_future is not None, "must call step before allreduce_grad"
-        self._quorum_future.result()
+        self.wait_quorum()

if not self.is_participating():
grad.zero_()
@@ -315,21 +313,28 @@ def callback(
self._pending_work.append(cast(torch.futures.Future[object], fut))
return fut

-    def start_step(self) -> None:
+    def start_quorum(self, allow_heal: bool = True) -> None:
"""
.. note::
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
Computes a new quorum (potentially asynchronously) and readies the
manager for a new step.
-        Must be called before the forwards pass of each step for best
+        It's best practice to call this before the forwards pass of each step for
         performance as computing quorum may take some time.
+
+        If allow_heal is set, the manager will attempt to heal either
+        synchronously before returning or asynchronously prior to any network
+        calls.
+
+        Args:
+            allow_heal: whether to allow healing at the beginning of the step

-        if self._should_step:
-            self._step += 1
-            self._batches_committed += self.num_participants()
+        # wait for previous quorum to complete
+        if self._quorum_future is not None:
+            self._quorum_future.result()

self._errored = None
self._healing = False
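Taken together with the `should_commit` change below, the intended call pattern looks roughly like the following. This is a hypothetical sketch: `model`, `optimizer`, and `dataloader` are placeholders, gradients are assumed to be routed through the manager (e.g. via its `allreduce_grad` hook), and the docs above recommend `torchft.optim.OptimizerWrapper` rather than calling these methods directly.

```python
for batch in dataloader:
    manager.start_quorum()       # form the quorum before the forward pass
    loss = model(batch).sum()
    loss.backward()              # grads are all-reduced through the manager
    if manager.should_commit():  # step counters now advance here on success
        optimizer.step()
    optimizer.zero_grad()
```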
@@ -338,9 +343,9 @@ def start_step(self) -> None:
# TODO: we should really be wrapping this whole section in a try-except
# block to allow gracefully recovering from issues in PG setup and quorum.

-        self._quorum_future = self._executor.submit(self._async_quorum)
+        self._quorum_future = self._executor.submit(self._async_quorum, allow_heal)
if not self._use_async_quorum:
-            self._quorum_future.result()
+            self.wait_quorum()

if self._healing:
# eagerly apply pending state_dict so we can run the forwards pass
@@ -350,7 +355,18 @@ def start_step(self) -> None:
# and don't need to zero_grad
self._healing = False

-    def _async_quorum(self) -> None:
+    def wait_quorum(self) -> None:
+        """
+        Wait for the quorum to complete.
+
+        ProcessGroup will be in a healthy state after this returns.
+        """
+        assert (
+            self._quorum_future is not None
+        ), "must call start_quorum before wait_quorum"
+        self._quorum_future.result()
+
+    def _async_quorum(self, allow_heal: bool) -> None:
(
quorum_id,
replica_rank,
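The new `wait_quorum` gives callers an explicit synchronization point. When async quorum is in use, this allows overlapping quorum computation with other work, roughly as below (a sketch; whether async quorum is enabled depends on how the manager was constructed, and the flag name is assumed):

```python
manager.start_quorum()    # returns immediately when async quorum is enabled
batch = next(dataloader)  # overlap data loading with quorum computation
manager.wait_quorum()     # block until the process group is healthy again
```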
@@ -372,7 +388,7 @@ def _async_quorum(self) -> None:
# workers will be healthy.
self._participating_rank, self._participating_world_size = (
(max_rank, max_world_size)
-            if self._use_async_quorum
+            if self._use_async_quorum or not allow_heal
else (replica_rank, replica_world_size)
)

@@ -397,7 +413,7 @@ def _async_quorum(self) -> None:
self._quorum_id = quorum_id

# See manager.rs for healing conditions
-        if heal:
+        if heal and allow_heal:
self._healing = True
self._logger.info(
f"healing required, fetching checkpoint server address from {address=} {max_step=}"
@@ -475,7 +491,9 @@ def should_commit(self) -> bool:
self._ckpt_server.disallow_checkpoint()

# decide whether we're in a healthy state to increase the step count
-        self._should_step = should_commit
+        if should_commit:
+            self._step += 1
+            self._batches_committed += self.num_participants()

return should_commit
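Moving the counter updates into `should_commit` changes the accounting on failure: a step that fails to commit no longer advances `_step` or `_batches_committed`, so it is retried at the same step number. A self-contained sketch of the new behavior (not the real `Manager`, just the accounting):

```python
class StepAccounting:
    """Minimal sketch of commit-time step accounting."""

    def __init__(self) -> None:
        self.step = 0
        self.batches_committed = 0

    def on_should_commit(self, should_commit: bool, num_participants: int) -> None:
        # Counters advance only when the quorum agreed to commit.
        if should_commit:
            self.step += 1
            self.batches_committed += num_participants

acct = StepAccounting()
acct.on_should_commit(False, 2)  # failed step: counters stay at (0, 0)
acct.on_should_commit(True, 2)   # committed:   counters become (1, 2)
```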
