Skip to content

Commit

Permalink
Better error when a bad directory is given for weight merging (huggin…
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Jul 12, 2024
1 parent 12a007d commit e1247de
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,29 @@ def merge_fsdp_weights(
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
checkpoint_dir = Path(checkpoint_dir)
from accelerate.state import PartialState

if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")

# Verify that the checkpoint directory exists
if not checkpoint_dir.exists():
model_path_exists = (checkpoint_dir / "pytorch_model_fsdp_0").exists()
optimizer_path_exists = (checkpoint_dir / "optimizer_0").exists()
err = f"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file."
if model_path_exists and optimizer_path_exists:
err += " However, potential model and optimizer checkpoint directories exist."
err += f"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0"
err += "instead."
elif model_path_exists:
err += " However, a potential model checkpoint directory exists."
err += f"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead."
elif optimizer_path_exists:
err += " However, a potential optimizer checkpoint directory exists."
err += f"Please try passing in {checkpoint_dir}/optimizer_0 instead."
raise ValueError(err)

# To setup `save` to work
state = PartialState()
if state.is_main_process:
Expand Down

0 comments on commit e1247de

Please sign in to comment.