From d6e4c24823c97b2cc6132d32700c56c5ef54e27e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 13:05:54 +0100 Subject: [PATCH] Fix OpenLLaMa default config bug. Skip Mistral test due to lack of GPU memory. --- spacy_llm/models/hf/openllama.py | 4 ++-- spacy_llm/tests/models/test_mistral.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 34248bc4..8ceb5bbc 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -2,7 +2,7 @@ from confection import SimpleFrozenDict -from ...compat import Literal, torch, transformers +from ...compat import Literal, transformers from ...registry.util import registry from .base import HuggingFace @@ -72,7 +72,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: return ( { **default_cfg_init, - "torch_dtype": torch.float16, + "torch_dtype": "float16", }, {**default_cfg_run, "max_new_tokens": 32}, ) diff --git a/spacy_llm/tests/models/test_mistral.py b/spacy_llm/tests/models/test_mistral.py index 5dde49f0..548d4d29 100644 --- a/spacy_llm/tests/models/test_mistral.py +++ b/spacy_llm/tests/models/test_mistral.py @@ -48,6 +48,7 @@ def test_init(): @pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") def test_init_from_config(): orig_config = Config().from_str(_NLP_CONFIG)