Commit

Fix GPU tests.
rmitsch committed Apr 20, 2024
1 parent a5109e2 commit f25092d
Showing 4 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -28,7 +28,8 @@ filterwarnings = [
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
     "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
     "ignore:^.*was deprecated in langchain-community.*",
-    "ignore:^.*was deprecated in LangChain 0.0.1.*"
+    "ignore:^.*was deprecated in LangChain 0.0.1.*",
+    "ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",
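Entries in pytest's filterwarnings list use Python's "action:message" warning-filter syntax, where the message part is a regular expression matched against the start of the warning text. As a rough plain-Python sketch of what the added entry is meant to do (parentheses escaped here because the pattern is a regex; this snippet is illustrative, not code from the commit):

import warnings

# Roughly what the new pyproject.toml entry asks pytest to install:
# silence the importlib load_module() deprecation emitted on Python 3.12.
warnings.filterwarnings(
    "ignore",
    message=r"^.*the load_module\(\) method is deprecated and slated for removal in Python 3\.12.*",
)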
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -13,7 +13,8 @@ langchain>=0.1,<0.2; python_version>="3.9"
 openai>=0.27,<=0.28.1; python_version>="3.9"
 
 # Necessary for running all local models on GPU.
-transformers[sentencepiece]>=4.0.0
+# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
+transformers[sentencepiece]>=4.0.0,<=4.38
 torch
 einops>=0.4
 
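The new upper bound keeps the dev environment on a transformers release known to work until the regression noted in the TODO is understood. A quick sanity check of the resolved version could look like this (hypothetical snippet using the third-party packaging library; not part of the commit):

from importlib.metadata import version

from packaging.version import Version

# Assumed check mirroring the pin in requirements-dev.txt.
installed = Version(version("transformers"))
assert Version("4.0.0") <= installed <= Version("4.38"), (
    f"transformers {installed} is outside the tested range; "
    "versions above 4.38 currently break local model handling."
)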
6 changes: 5 additions & 1 deletion spacy_llm/tests/models/test_dolly.py
@@ -1,4 +1,5 @@
 import copy
+import warnings
 
 import pytest
 import spacy
@@ -42,7 +43,9 @@
 def test_init():
     """Test initialization and simple run."""
     nlp = spacy.blank("en")
-    nlp.add_pipe("llm", config=_PIPE_CFG)
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=DeprecationWarning)
+        nlp.add_pipe("llm", config=_PIPE_CFG)
     doc = nlp("This is a test.")
     nlp.get_pipe("llm")._model.get_model_names()
     torch.cuda.empty_cache()
@@ -53,6 +56,7 @@ def test_init():
 
 @pytest.mark.gpu
 @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
+@pytest.mark.filterwarnings("ignore:the load_module() method is deprecated")
 def test_init_from_config():
     orig_config = Config().from_str(_NLP_CONFIG)
     nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
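The warnings.catch_warnings() context manager saves and restores the global warning filters, so the DeprecationWarning suppression only applies while the pipe is being added and does not leak into other tests. A minimal standalone sketch of the same pattern (illustrative names, not from the repository):

import warnings


def noisy_setup():
    # Stand-in for nlp.add_pipe(...), which triggers a DeprecationWarning
    # from the underlying model-loading machinery.
    warnings.warn("the load_module() method is deprecated", DeprecationWarning)


with warnings.catch_warnings():
    # Filters added inside this block are discarded when it exits,
    # so the rest of the test session sees DeprecationWarnings as usual.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    noisy_setup()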
2 changes: 2 additions & 0 deletions spacy_llm/tests/models/test_falcon.py
@@ -39,6 +39,7 @@
 
 @pytest.mark.gpu
 @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
+@pytest.mark.filterwarnings("ignore:the load_module() method is deprecated")
 def test_init():
     """Test initialization and simple run."""
     nlp = spacy.blank("en")
@@ -53,6 +54,7 @@ def test_init():
 
 @pytest.mark.gpu
 @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
+@pytest.mark.filterwarnings("ignore:the load_module() method is deprecated")
 def test_init_from_config():
     orig_config = Config().from_str(_NLP_CONFIG)
     nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
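Unlike the project-wide filterwarnings entry in pyproject.toml, the @pytest.mark.filterwarnings marker scopes the suppression to a single test. A minimal illustration with a hypothetical test (the message field is treated as a regex, hence the escaped parentheses in this sketch; not code from the repository):

import warnings

import pytest


@pytest.mark.filterwarnings(r"ignore:the load_module\(\) method is deprecated")
def test_ignores_load_module_deprecation():
    # The marker installs the filter only for the duration of this test,
    # so the warning below is silenced without touching other tests.
    warnings.warn(
        "the load_module() method is deprecated and slated for removal in Python 3.12",
        DeprecationWarning,
    )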
