Skip to content

Commit

Permalink
Fix default dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent 05cffab commit 161d14a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion spacy_llm/models/hf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
default_cfg_run: Dict[str, Any] = {}

if has_torch:
default_cfg_init["torch_dtype"] = torch.bfloat16
default_cfg_init["torch_dtype"] = "bfloat16"
if has_torch_cuda_gpu:
# this ensures it fails explicitely when GPU is not enabled or sufficient
default_cfg_init["device"] = "cuda:0"
Expand Down

0 comments on commit 161d14a

Please sign in to comment.