diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py index 1b785808..306e879a 100644 --- a/spacy_llm/tests/models/test_hf.py +++ b/spacy_llm/tests/models/test_hf.py @@ -50,6 +50,8 @@ def test_device_config_conflict(model: Tuple[str, str]): with pytest.raises(ImportError, match="requires Accelerate"): nlp.add_pipe("llm", name="llm3", config=cfg) + torch.cuda.empty_cache() + @pytest.mark.gpu @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") @@ -74,3 +76,5 @@ def test_torch_dtype(): cfg["model"]["config_init"] = {"torch_dtype": "float999"} # type: ignore[index] with pytest.raises(ValueError, match="Invalid value float999"): nlp.add_pipe("llm", name="llm3", config=cfg) + + torch.cuda.empty_cache()