Skip to content

Commit

Permalink
Change device check sequence.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent 565fbef commit 17a6402
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions spacy_llm/models/hf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,20 @@ def __init__(
self._config_run = {**self._config_run, **config_run}

# `device` and `device_map` are conflicting arguments - ensure they aren't both set.
# Case 1: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user.
if config_init:
if "device" in default_cfg_init and "device_map" in config_init:
self._config_init.pop("device")
# Case 2: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user.
if "device_map" in default_cfg_init and "device" in config_init:
self._config_init.pop("device_map")
# Case 3: both explicitly set by user.
# Case 1: both device and device_map explicitly set by user.
if "device" in config_init and "device_map" in config_init:
warnings.warn(
"`device` and `device_map` are conflicting arguments - don't set both. Dropping argument "
"`device`."
)
if "device" in self._config_init:
self._config_init.pop("device")
self._config_init.pop("device")
# Case 2: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user.
elif "device" in default_cfg_init and "device_map" in config_init:
self._config_init.pop("device")
# Case 3: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user.
elif "device_map" in default_cfg_init and "device" in config_init:
self._config_init.pop("device_map")

# Fetch proper torch.dtype, if specified.
if (
Expand Down

0 comments on commit 17a6402

Please sign in to comment.