Skip to content

Commit

Permalink
Raise errors on all ranks for checkpoint download failures (#3345)
Browse files Browse the repository at this point in the history
Co-authored-by: Ning Wang <[email protected]>
  • Loading branch information
irenedea and bigning authored May 31, 2024
1 parent a6c4e43 commit 3c0a817
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
all_file_paths = dist.all_gather_object(relative_file_paths)

# 2. Download to the destination all files this rank needs if on first replica
download_error = False
if first_replica:
log.debug(f'Rank {dist.get_global_rank()} starting to download files.')

Expand Down Expand Up @@ -275,12 +276,26 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
download_object_or_file(object_name, file_destination, self.object_store)
log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
except Exception as e:
# PyTorch will capture any exception of this function,
# and dist.all_gather_objects(exception) before raising it.
# If that all_gather_objects fails, the exception is never visible to user.
# We immediately print the exception to avoid that situation.
log.error(f'Exception {type(e)} raised during downloading: {str(e)}')
raise e
download_error = True

# PyTorch captures any exception raised by this function and runs
# dist.all_gather_objects(exception) on it before re-raising it.
# If that all_gather_objects call fails, the exception is never visible to the user.
# We therefore raise an error on all ranks to ensure the user sees it.
download_error_tensor = dist.get_device(None).tensor_to_device(torch.tensor(1 if download_error else 0))
error_by_rank = dist.all_gather(download_error_tensor)
failed_ranks = []
for rank, error in enumerate(list(error_by_rank)):
if error > 0:
failed_ranks.append(rank)
download_error = True

if download_error:
raise RuntimeError(
f'Ranks {failed_ranks} failed to download.',
'To see the full error please look at the logs for that rank, which are logged via log.error.',
)

# 3. Wait for all ranks to finish.
log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
Expand Down

0 comments on commit 3c0a817

Please sign in to comment.