From 161d14a2902a15c996759b6411c22e9b35680ef6 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 11:59:42 +0100 Subject: [PATCH] Fix default dtype. --- spacy_llm/models/hf/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 46cc5942..8e3e55b6 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -124,7 +124,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run: Dict[str, Any] = {} if has_torch: - default_cfg_init["torch_dtype"] = torch.bfloat16 + default_cfg_init["torch_dtype"] = "bfloat16" if has_torch_cuda_gpu: # this ensures it fails explicitely when GPU is not enabled or sufficient default_cfg_init["device"] = "cuda:0"