diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py
index 71fdc074..87209118 100644
--- a/spacy_llm/models/hf/base.py
+++ b/spacy_llm/models/hf/base.py
@@ -27,12 +27,46 @@ def __init__(
         inference_config (Dict[Any, Any]): HF config for model run.
         """
         self._name = name if self.hf_account in name else f"{self.hf_account}/{name}"
-        self._config_init, self._config_run = self.compile_default_configs()
+        default_cfg_init, default_cfg_run = self.compile_default_configs()
+        self._config_init, self._config_run = default_cfg_init, default_cfg_run
+
         if config_init:
             self._config_init = {**self._config_init, **config_init}
         if config_run:
             self._config_run = {**self._config_run, **config_run}
 
+        # `device` and `device_map` are conflicting arguments - ensure they aren't both set.
+        if config_init:
+            # Case 1: both device and device_map explicitly set by user.
+            if "device" in config_init and "device_map" in config_init:
+                warnings.warn(
+                    "`device` and `device_map` are conflicting arguments - don't set both. Dropping argument "
+                    "`device`."
+                )
+                self._config_init.pop("device")
+            # Case 2: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user.
+            elif "device" in default_cfg_init and "device_map" in config_init:
+                self._config_init.pop("device")
+            # Case 3: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user.
+            elif "device_map" in default_cfg_init and "device" in config_init:
+                self._config_init.pop("device_map")
+
+        # Fetch proper torch.dtype, if specified.
+        if (
+            has_torch
+            and "torch_dtype" in self._config_init
+            and self._config_init["torch_dtype"] != "auto"
+        ):
+            try:
+                self._config_init["torch_dtype"] = getattr(
+                    torch, self._config_init["torch_dtype"]
+                )
+            except AttributeError as ex:
+                raise ValueError(
+                    f"Invalid value {self._config_init['torch_dtype']} was specified for `torch_dtype`. "
+                    f"Double-check you specified a valid dtype."
+                ) from ex
+
         # Init HF model.
         HuggingFace.check_installation()
         self._check_model()
@@ -89,7 +123,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
         default_cfg_run: Dict[str, Any] = {}
 
         if has_torch:
-            default_cfg_init["torch_dtype"] = torch.bfloat16
+            default_cfg_init["torch_dtype"] = "bfloat16"
             if has_torch_cuda_gpu:
                 # this ensures it fails explicitely when GPU is not enabled or sufficient
                 default_cfg_init["device"] = "cuda:0"
@@ -106,6 +140,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
                     "Install CUDA to load and run the LLM on the GPU, or install 'accelerate' to dynamically "
                     "distribute the LLM on the CPU or even the hard disk. The latter may be slow."
                 )
+        return default_cfg_init, default_cfg_run
 
     @abc.abstractmethod
diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py
index 76d4e9e2..2e18ac9d 100644
--- a/spacy_llm/models/hf/falcon.py
+++ b/spacy_llm/models/hf/falcon.py
@@ -19,7 +19,6 @@ def __init__(
         config_run: Optional[Dict[str, Any]],
     ):
         self._tokenizer: Optional["transformers.AutoTokenizer"] = None
-        self._device: Optional[str] = None
         super().__init__(name=name, config_init=config_init, config_run=config_run)
 
         assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)
diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py
index 6fe78c78..56ae7be3 100644
--- a/spacy_llm/models/hf/mistral.py
+++ b/spacy_llm/models/hf/mistral.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional
 
 from confection import SimpleFrozenDict
 
@@ -17,7 +17,6 @@ def __init__(
         config_run: Optional[Dict[str, Any]],
     ):
         self._tokenizer: Optional["transformers.AutoTokenizer"] = None
-        self._device: Optional[str] = None
         self._is_instruct = "instruct" in name
         super().__init__(name=name, config_init=config_init, config_run=config_run)
 
@@ -33,14 +32,15 @@ def __init__(
     def init_model(self) -> Any:
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name)
         init_cfg = self._config_init
+        device: Optional[str] = None
         if "device" in init_cfg:
-            self._device = init_cfg.pop("device")
+            device = init_cfg.pop("device")
 
         model = transformers.AutoModelForCausalLM.from_pretrained(
             self._name, **init_cfg, resume_download=True
         )
-        if self._device:
-            model.to(self._device)
+        if device:
+            model.to(device)
 
         return model
 
@@ -61,8 +61,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[ove
             ).input_ids
             for prompt in prompts
         ]
-        if self._device:
-            tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids]
+        tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids]
 
         return [
             self._tokenizer.decode(
@@ -74,14 +73,6 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[ove
             for tok_ii in tokenized_input_ids
         ]
 
-    @staticmethod
-    def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
-        return (
-            default_cfg_init,
-            default_cfg_run,
-        )
-
 
 @registry.llm_models("spacy.Mistral.v1")
 def mistral_hf(
diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py
index 4cf2f4cf..8ceb5bbc 100644
--- a/spacy_llm/models/hf/openllama.py
+++ b/spacy_llm/models/hf/openllama.py
@@ -2,7 +2,7 @@
 
 from confection import SimpleFrozenDict
 
-from ...compat import Literal, torch, transformers
+from ...compat import Literal, transformers
 from ...registry.util import registry
 from .base import HuggingFace
 
@@ -22,7 +22,6 @@ def __init__(
         config_run: Optional[Dict[str, Any]],
    ):
         self._tokenizer: Optional["transformers.AutoTokenizer"] = None
-        self._device: Optional[str] = None
         super().__init__(name=name, config_init=config_init, config_run=config_run)
 
     def init_model(self) -> "transformers.AutoModelForCausalLM":
@@ -32,14 +31,15 @@ def init_model(self) -> "transformers.AutoModelForCausalLM":
         # Initialize tokenizer and model.
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name)
         init_cfg = self._config_init
+        device: Optional[str] = None
         if "device" in init_cfg:
-            self._device = init_cfg.pop("device")
+            device = init_cfg.pop("device")
+
         model = transformers.AutoModelForCausalLM.from_pretrained(
             self._name, **init_cfg
         )
-
-        if self._device:
-            model.to(self._device)
+        if device:
+            model.to(device)
 
         return model
 
@@ -48,8 +48,9 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[ove
         tokenized_input_ids = [
             self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts
         ]
-        if self._device:
-            tokenized_input_ids = [tii.to(self._device) for tii in tokenized_input_ids]
+        tokenized_input_ids = [
+            tii.to(self._model.device) for tii in tokenized_input_ids
+        ]
 
         assert hasattr(self._model, "generate")
         return [
@@ -71,7 +72,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
         return (
             {
                 **default_cfg_init,
-                "torch_dtype": torch.float16,
+                "torch_dtype": "float16",
             },
             {**default_cfg_run, "max_new_tokens": 32},
         )
diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py
index 4711d69f..34698e0e 100644
--- a/spacy_llm/models/hf/stablelm.py
+++ b/spacy_llm/models/hf/stablelm.py
@@ -42,7 +42,6 @@ def __init__(
     ):
         self._tokenizer: Optional["transformers.AutoTokenizer"] = None
         self._is_tuned = "tuned" in name
-        self._device: Optional[str] = None
         super().__init__(name=name, config_init=config_init, config_run=config_run)
 
     def init_model(self) -> "transformers.AutoModelForCausalLM":
@@ -51,14 +50,15 @@ def init_model(self) -> "transformers.AutoModelForCausalLM":
         """
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name)
         init_cfg = self._config_init
+        device: Optional[str] = None
         if "device" in init_cfg:
-            self._device = init_cfg.pop("device")
+            device = init_cfg.pop("device")
+
         model = transformers.AutoModelForCausalLM.from_pretrained(
             self._name, **init_cfg
         )
-
-        if self._device:
-            model.half().to(self._device)
+        if device:
+            model.half().to(device)
 
         return model
 
@@ -80,8 +80,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[ove
                 ]
             )
         ]
-        if self._device:
-            tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids]
+        tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids]
 
         assert hasattr(self._model, "generate")
         return [
diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py
new file mode 100644
index 00000000..3058035c
--- /dev/null
+++ b/spacy_llm/tests/models/test_hf.py
@@ -0,0 +1,78 @@
+from typing import Tuple
+
+import pytest
+import spacy
+from thinc.compat import has_torch_cuda_gpu
+
+from spacy_llm.compat import has_accelerate, torch
+
+_PIPE_CFG = {
+    "model": {
+        "@llm_models": "",
+        "name": "",
+    },
+    "task": {"@llm_tasks": "spacy.NoOp.v1"},
+    "save_io": True,
+}
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
+@pytest.mark.parametrize(
+    "model", (("spacy.Dolly.v1", "dolly-v2-3b"), ("spacy.Llama2.v1", "Llama-2-7b-hf"))
+)
+def test_device_config_conflict(model: Tuple[str, str]):
+    """Test device configuration."""
+    nlp = spacy.blank("en")
+    model, name = model
+    cfg = {**_PIPE_CFG, **{"model": {"@llm_models": model, "name": name}}}
+
+    # Set device only.
+    cfg["model"]["config_init"] = {"device": "cpu"}  # type: ignore[index]
+    nlp.add_pipe("llm", name="llm1", config=cfg)
+
+    # Set device_map only.
+    cfg["model"]["config_init"] = {"device_map": "auto"}  # type: ignore[index]
+    if has_accelerate:
+        nlp.add_pipe("llm", name="llm2", config=cfg)
+    else:
+        with pytest.raises(ImportError, match="requires Accelerate"):
+            nlp.add_pipe("llm", name="llm2", config=cfg)
+
+    # Set device_map and device.
+    cfg["model"]["config_init"] = {"device_map": "auto", "device": "cpu"}  # type: ignore[index]
+    with pytest.warns(UserWarning, match="conflicting arguments"):
+        if has_accelerate:
+            nlp.add_pipe("llm", name="llm3", config=cfg)
+        else:
+            with pytest.raises(ImportError, match="requires Accelerate"):
+                nlp.add_pipe("llm", name="llm3", config=cfg)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
+def test_torch_dtype():
+    """Test torch_dtype setting."""
+    nlp = spacy.blank("en")
+    cfg = {
+        **_PIPE_CFG,
+        **{"model": {"@llm_models": "spacy.Dolly.v1", "name": "dolly-v2-3b"}},
+    }
+
+    # Should be converted to torch.float16.
+    cfg["model"]["config_init"] = {"torch_dtype": "float16"}  # type: ignore[index]
+    llm = nlp.add_pipe("llm", name="llm1", config=cfg)
+    assert llm._model._config_init["torch_dtype"] == torch.float16
+
+    # Should remain "auto".
+    cfg["model"]["config_init"] = {"torch_dtype": "auto"}  # type: ignore[index]
+    nlp.add_pipe("llm", name="llm2", config=cfg)
+
+    # Should fail - nonexistent dtype.
+    cfg["model"]["config_init"] = {"torch_dtype": "float999"}  # type: ignore[index]
+    with pytest.raises(ValueError, match="Invalid value float999"):
+        nlp.add_pipe("llm", name="llm3", config=cfg)
+
+    torch.cuda.empty_cache()
diff --git a/spacy_llm/tests/models/test_mistral.py b/spacy_llm/tests/models/test_mistral.py
index 5dde49f0..548d4d29 100644
--- a/spacy_llm/tests/models/test_mistral.py
+++ b/spacy_llm/tests/models/test_mistral.py
@@ -48,6 +48,7 @@ def test_init():
 
 
 @pytest.mark.gpu
+@pytest.mark.skip(reason="CI runner needs more GPU memory")
 @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
 def test_init_from_config():
     orig_config = Config().from_str(_NLP_CONFIG)
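
For reference, a minimal usage sketch of the behaviour this patch introduces (not part of the diff itself): `torch_dtype` is now passed via `config_init` as a string and resolved to the corresponding `torch.dtype` at init time, and a user-supplied `device_map` takes precedence over the default `device`. The model name and dtype below are illustrative choices, `device_map="auto"` additionally requires `accelerate`, and actually running the pipeline loads the model into memory.

# Usage sketch - assumes spacy and spacy-llm are installed and enough memory is available.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        "model": {
            "@llm_models": "spacy.Dolly.v1",
            "name": "dolly-v2-3b",  # illustrative model choice
            "config_init": {
                # Passed as a string; resolved to torch.float16 by HuggingFace.__init__.
                "torch_dtype": "float16",
                # Needs `accelerate`; the default `device` is dropped in its favour.
                "device_map": "auto",
            },
        },
    },
)
doc = nlp("This text only passes through the NoOp task.")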