Skip to content

Commit

Permalink
Fix OpenLLaMa default config bug. Skip Mistral test due to lack of GP…
Browse files Browse the repository at this point in the history
…U memory.
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent af64c71 commit d6e4c24
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions spacy_llm/models/hf/openllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from confection import SimpleFrozenDict

from ...compat import Literal, torch, transformers
from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace

Expand Down Expand Up @@ -72,7 +72,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
return (
{
**default_cfg_init,
"torch_dtype": torch.float16,
"torch_dtype": "float16",
},
{**default_cfg_run, "max_new_tokens": 32},
)
Expand Down
1 change: 1 addition & 0 deletions spacy_llm/tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_init():


@pytest.mark.gpu
@pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
orig_config = Config().from_str(_NLP_CONFIG)
Expand Down

0 comments on commit d6e4c24

Please sign in to comment.