Skip to content

Commit

Permalink
Fix verify_device_map (#1842)
Browse files Browse the repository at this point in the history
* make verify_device_map return True only if device map has more that 1 element

* Fix style and comment

* fix style
  • Loading branch information
Rexhaif authored Aug 14, 2023
1 parent 6458058 commit f67e11a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3059,7 +3059,9 @@ def verify_device_map(self, model: torch.nn.Module) -> bool:
"""
Verifies that `model` has not been prepared with big model inference with a device-map resembling `auto`.
"""
# Checks if any of the child module has the attribute `hf_device_map`.
has_hf_device_map = any(hasattr(m, "hf_device_map") for m in model.modules())
# Checks if any of the child modules has the attribute `hf_device_map` and this map has more than one entry.
for m in model.modules():
if hasattr(m, "hf_device_map") and len(m.hf_device_map) > 1:
return True

return has_hf_device_map
return False

0 comments on commit f67e11a

Please sign in to comment.