diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 34248bc4..8ceb5bbc 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -2,7 +2,7 @@ from confection import SimpleFrozenDict -from ...compat import Literal, torch, transformers +from ...compat import Literal, transformers from ...registry.util import registry from .base import HuggingFace @@ -72,7 +72,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: return ( { **default_cfg_init, - "torch_dtype": torch.float16, + "torch_dtype": "float16", }, {**default_cfg_run, "max_new_tokens": 32}, ) diff --git a/spacy_llm/tests/models/test_mistral.py b/spacy_llm/tests/models/test_mistral.py index 5dde49f0..548d4d29 100644 --- a/spacy_llm/tests/models/test_mistral.py +++ b/spacy_llm/tests/models/test_mistral.py @@ -48,6 +48,7 @@ def test_init(): @pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") def test_init_from_config(): orig_config = Config().from_str(_NLP_CONFIG)