From f688e6ceb805e66b55df76736c0feae3ce37a491 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 9 Nov 2023 17:04:42 +0100 Subject: [PATCH 01/10] Consider device/device_map conflicts. Always move inputs to model device for models using AutoModelForCausalLM. --- spacy_llm/models/hf/base.py | 22 +++++++++++++++++++++- spacy_llm/models/hf/falcon.py | 1 - spacy_llm/models/hf/mistral.py | 21 ++++++--------------- spacy_llm/models/hf/openllama.py | 15 ++++++++------- spacy_llm/models/hf/stablelm.py | 13 ++++++------- 5 files changed, 41 insertions(+), 31 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 71fdc074..45756a59 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -27,12 +27,31 @@ def __init__( inference_config (Dict[Any, Any]): HF config for model run. """ self._name = name if self.hf_account in name else f"{self.hf_account}/{name}" - self._config_init, self._config_run = self.compile_default_configs() + default_cfg_init, default_cfg_run = self.compile_default_configs() + self._config_init, self._config_run = default_cfg_init, default_cfg_run + if config_init: self._config_init = {**self._config_init, **config_init} if config_run: self._config_run = {**self._config_run, **config_run} + # `device` and `device_map` are conflicting arguments - ensure they aren't both set. + # Case 1: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user. + if config_init: + if "device" in default_cfg_init and "device_map" in config_init: + self._config_init.pop("device") + # Case 2: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user. + if "device_map" in default_cfg_init and "device" in config_init: + self._config_init.pop("device_map") + # Case 3: both explicitly set by user. + if "device" in config_init and "device_map" in config_init: + warnings.warn( + "`device` and `device_map` are conflicting arguments - don't set both. Dropping argument " + "`device`." + ) + if "device" in self._config_init: + self._config_init.pop("device") + # Init HF model. HuggingFace.check_installation() self._check_model() @@ -106,6 +125,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: "Install CUDA to load and run the LLM on the GPU, or install 'accelerate' to dynamically " "distribute the LLM on the CPU or even the hard disk. The latter may be slow." 
) + return default_cfg_init, default_cfg_run @abc.abstractmethod diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py index 76d4e9e2..2e18ac9d 100644 --- a/spacy_llm/models/hf/falcon.py +++ b/spacy_llm/models/hf/falcon.py @@ -19,7 +19,6 @@ def __init__( config_run: Optional[Dict[str, Any]], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - self._device: Optional[str] = None super().__init__(name=name, config_init=config_init, config_run=config_run) assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index 6fe78c78..56ae7be3 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -17,7 +17,6 @@ def __init__( config_run: Optional[Dict[str, Any]], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - self._device: Optional[str] = None self._is_instruct = "instruct" in name super().__init__(name=name, config_init=config_init, config_run=config_run) @@ -33,14 +32,15 @@ def __init__( def init_model(self) -> Any: self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name) init_cfg = self._config_init + device: Optional[str] = None if "device" in init_cfg: - self._device = init_cfg.pop("device") + device = init_cfg.pop("device") model = transformers.AutoModelForCausalLM.from_pretrained( self._name, **init_cfg, resume_download=True ) - if self._device: - model.to(self._device) + if device: + model.to(device) return model @@ -61,8 +61,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove ).input_ids for prompt in prompts ] - if self._device: - tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids] + tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids] return [ self._tokenizer.decode( @@ -74,14 +73,6 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove for tok_ii in tokenized_input_ids ] - @staticmethod - def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: - default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() - return ( - default_cfg_init, - default_cfg_run, - ) - @registry.llm_models("spacy.Mistral.v1") def mistral_hf( diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 4cf2f4cf..34248bc4 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -22,7 +22,6 @@ def __init__( config_run: Optional[Dict[str, Any]], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - self._device: Optional[str] = None super().__init__(name=name, config_init=config_init, config_run=config_run) def init_model(self) -> "transformers.AutoModelForCausalLM": @@ -32,14 +31,15 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": # Initialize tokenizer and model. 
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name) init_cfg = self._config_init + device: Optional[str] = None if "device" in init_cfg: - self._device = init_cfg.pop("device") + device = init_cfg.pop("device") + model = transformers.AutoModelForCausalLM.from_pretrained( self._name, **init_cfg ) - - if self._device: - model.to(self._device) + if device: + model.to(device) return model @@ -48,8 +48,9 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove tokenized_input_ids = [ self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts ] - if self._device: - tokenized_input_ids = [tii.to(self._device) for tii in tokenized_input_ids] + tokenized_input_ids = [ + tii.to(self._model.device) for tii in tokenized_input_ids + ] assert hasattr(self._model, "generate") return [ diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py index 4711d69f..34698e0e 100644 --- a/spacy_llm/models/hf/stablelm.py +++ b/spacy_llm/models/hf/stablelm.py @@ -42,7 +42,6 @@ def __init__( ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None self._is_tuned = "tuned" in name - self._device: Optional[str] = None super().__init__(name=name, config_init=config_init, config_run=config_run) def init_model(self) -> "transformers.AutoModelForCausalLM": @@ -51,14 +50,15 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": """ self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name) init_cfg = self._config_init + device: Optional[str] = None if "device" in init_cfg: - self._device = init_cfg.pop("device") + device = init_cfg.pop("device") + model = transformers.AutoModelForCausalLM.from_pretrained( self._name, **init_cfg ) - - if self._device: - model.half().to(self._device) + if device: + model.half().to(device) return model @@ -80,8 +80,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove ] ) ] - if self._device: - tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids] + tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids] assert hasattr(self._model, "generate") return [ From 139d2b305542579dcaa8b5c04f02ea52adc23bc1 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 10 Nov 2023 10:28:29 +0100 Subject: [PATCH 02/10] Add test for device conflict checks. --- requirements-dev.txt | 2 ++ spacy_llm/tests/models/test_hf.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 spacy_llm/tests/models/test_hf.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 49239d6e..ee9c0c07 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,6 +16,8 @@ openai>=0.27,<=0.28.1; python_version>="3.9" transformers[sentencepiece]>=4.0.0 torch einops>=0.4 +# For testing device mapping. +accelerate # Necessary for pytest checks and ignores. 
sqlalchemy; python_version<"3.9" diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py new file mode 100644 index 00000000..935b2aae --- /dev/null +++ b/spacy_llm/tests/models/test_hf.py @@ -0,0 +1,39 @@ +from typing import Tuple + +import pytest +import spacy +from thinc.compat import has_torch_cuda_gpu + +_PIPE_CFG = { + "model": { + "@llm_models": "", + "name": "", + }, + "task": {"@llm_tasks": "spacy.NoOp.v1"}, + "save_io": True, +} + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +@pytest.mark.parametrize( + "model", (("spacy.Dolly.v1", "dolly-v2-3b"), ("spacy.Llama2.v1", "Llama-2-7b-hf")) +) +def test_device_config_conflict(model: Tuple[str, str]): + """Test initialization and simple run.""" + nlp = spacy.blank("en") + model, name = model + cfg = {**_PIPE_CFG, **{"model": {"@llm_models": model, "name": name}}} + + # Set device only. + cfg["model"]["config_init"] = {"device": "cpu"} # type: ignore[index] + nlp.add_pipe("llm", name="llm1", config=cfg) + + # Set device_map only. + cfg["model"]["config_init"] = {"device_map": "auto"} # type: ignore[index] + nlp.add_pipe("llm", name="llm2", config=cfg) + + # Set device_map and device. + cfg["model"]["config_init"] = {"device_map": "auto", "device": "cpu"} # type: ignore[index] + with pytest.warns(UserWarning, match="conflicting arguments"): + nlp.add_pipe("llm", name="llm3", config=cfg) From ad9583451bdbab68bb7157b25203820e1bfcc6f4 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 10 Nov 2023 12:59:59 +0100 Subject: [PATCH 03/10] Remove accelerate dependency. --- requirements-dev.txt | 2 -- spacy_llm/tests/models/test_hf.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index ee9c0c07..49239d6e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,8 +16,6 @@ openai>=0.27,<=0.28.1; python_version>="3.9" transformers[sentencepiece]>=4.0.0 torch einops>=0.4 -# For testing device mapping. -accelerate # Necessary for pytest checks and ignores. sqlalchemy; python_version<"3.9" diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py index 935b2aae..5164eeb0 100644 --- a/spacy_llm/tests/models/test_hf.py +++ b/spacy_llm/tests/models/test_hf.py @@ -4,6 +4,8 @@ import spacy from thinc.compat import has_torch_cuda_gpu +from spacy_llm.compat import has_accelerate + _PIPE_CFG = { "model": { "@llm_models": "", @@ -31,9 +33,17 @@ def test_device_config_conflict(model: Tuple[str, str]): # Set device_map only. cfg["model"]["config_init"] = {"device_map": "auto"} # type: ignore[index] - nlp.add_pipe("llm", name="llm2", config=cfg) + if has_accelerate: + nlp.add_pipe("llm", name="llm2", config=cfg) + else: + with pytest.raises(ImportError, match="requires Accelerate"): + nlp.add_pipe("llm", name="llm2", config=cfg) # Set device_map and device. cfg["model"]["config_init"] = {"device_map": "auto", "device": "cpu"} # type: ignore[index] with pytest.warns(UserWarning, match="conflicting arguments"): - nlp.add_pipe("llm", name="llm3", config=cfg) + if has_accelerate: + nlp.add_pipe("llm", name="llm3", config=cfg) + else: + with pytest.raises(ImportError, match="requires Accelerate"): + nlp.add_pipe("llm", name="llm3", config=cfg) From 9f3cca37ac1dcec74f053311d719f102bdd6499c Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 10 Nov 2023 16:25:04 +0100 Subject: [PATCH 04/10] Add torch_dtype handling. 
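Accept `torch_dtype` as a string in `config_init` and resolve it to the matching
`torch.dtype` in `HuggingFace.__init__()`; "auto" is passed through unchanged, and an
unknown dtype name raises a `ValueError`. A minimal usage sketch mirroring the new
test (the pipe configuration and model choice are illustrative only):

    import spacy

    nlp = spacy.blank("en")
    cfg = {
        "model": {
            "@llm_models": "spacy.Dolly.v1",
            "name": "dolly-v2-3b",
            # The string "float16" is resolved to torch.float16 at init time.
            "config_init": {"torch_dtype": "float16"},
        },
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
    }
    nlp.add_pipe("llm", config=cfg)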
--- spacy_llm/models/hf/base.py | 12 ++++++++++++ spacy_llm/tests/models/test_hf.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 45756a59..71a3cafc 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -52,6 +52,18 @@ def __init__( if "device" in self._config_init: self._config_init.pop("device") + # Fetch proper torch.dtype, if specified. + if has_torch and self._config_init.get("torch_dtype", "") not in ("", "auto"): + try: + self._config_init["torch_dtype"] = getattr( + torch, self._config_init["torch_dtype"] + ) + except AttributeError as ex: + raise ValueError( + f"Invalid value {self._config_init['torch_dtype']} was specified for `torch_dtype`. " + f"Double-check you specified a valid dtype." + ) from ex + # Init HF model. HuggingFace.check_installation() self._check_model() diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py index 5164eeb0..1b785808 100644 --- a/spacy_llm/tests/models/test_hf.py +++ b/spacy_llm/tests/models/test_hf.py @@ -6,6 +6,8 @@ from spacy_llm.compat import has_accelerate +from ...compat import torch + _PIPE_CFG = { "model": { "@llm_models": "", @@ -22,7 +24,7 @@ "model", (("spacy.Dolly.v1", "dolly-v2-3b"), ("spacy.Llama2.v1", "Llama-2-7b-hf")) ) def test_device_config_conflict(model: Tuple[str, str]): - """Test initialization and simple run.""" + """Test device configuration.""" nlp = spacy.blank("en") model, name = model cfg = {**_PIPE_CFG, **{"model": {"@llm_models": model, "name": name}}} @@ -47,3 +49,28 @@ def test_device_config_conflict(model: Tuple[str, str]): else: with pytest.raises(ImportError, match="requires Accelerate"): nlp.add_pipe("llm", name="llm3", config=cfg) + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_torch_dtype(): + """Test torch_dtype setting.""" + nlp = spacy.blank("en") + cfg = { + **_PIPE_CFG, + **{"model": {"@llm_models": "spacy.Dolly.v1", "name": "dolly-v2-3b"}}, + } + + # Should be converted to torch.float16. + cfg["model"]["config_init"] = {"torch_dtype": "float16"} # type: ignore[index] + llm = nlp.add_pipe("llm", name="llm1", config=cfg) + assert llm._model._config_init["torch_dtype"] == torch.float16 + + # Should remain "auto". + cfg["model"]["config_init"] = {"torch_dtype": "auto"} # type: ignore[index] + nlp.add_pipe("llm", name="llm2", config=cfg) + + # Should fail - nonexistent dtype. + cfg["model"]["config_init"] = {"torch_dtype": "float999"} # type: ignore[index] + with pytest.raises(ValueError, match="Invalid value float999"): + nlp.add_pipe("llm", name="llm3", config=cfg) From 05cffab151450432865b41d1d5abec4b9a424d2c Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 11:39:40 +0100 Subject: [PATCH 05/10] Fix torch_dtype check. --- spacy_llm/models/hf/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 71a3cafc..46cc5942 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -53,7 +53,11 @@ def __init__( self._config_init.pop("device") # Fetch proper torch.dtype, if specified. 
- if has_torch and self._config_init.get("torch_dtype", "") not in ("", "auto"): + if ( + has_torch + and "torch_dtype" in self._config_init + and self._config_init["torch_dtype"] != "auto" + ): try: self._config_init["torch_dtype"] = getattr( torch, self._config_init["torch_dtype"] From 161d14a2902a15c996759b6411c22e9b35680ef6 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 11:59:42 +0100 Subject: [PATCH 06/10] Fix default dtype. --- spacy_llm/models/hf/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 46cc5942..8e3e55b6 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -124,7 +124,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run: Dict[str, Any] = {} if has_torch: - default_cfg_init["torch_dtype"] = torch.bfloat16 + default_cfg_init["torch_dtype"] = "bfloat16" if has_torch_cuda_gpu: # this ensures it fails explicitely when GPU is not enabled or sufficient default_cfg_init["device"] = "cuda:0" From af64c712f0be7e2b5f176914198088e39ace592f Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 12:40:00 +0100 Subject: [PATCH 07/10] Empty cache in HF tests. --- spacy_llm/tests/models/test_hf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py index 1b785808..306e879a 100644 --- a/spacy_llm/tests/models/test_hf.py +++ b/spacy_llm/tests/models/test_hf.py @@ -50,6 +50,8 @@ def test_device_config_conflict(model: Tuple[str, str]): with pytest.raises(ImportError, match="requires Accelerate"): nlp.add_pipe("llm", name="llm3", config=cfg) + torch.cuda.empty_cache() + @pytest.mark.gpu @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") @@ -74,3 +76,5 @@ def test_torch_dtype(): cfg["model"]["config_init"] = {"torch_dtype": "float999"} # type: ignore[index] with pytest.raises(ValueError, match="Invalid value float999"): nlp.add_pipe("llm", name="llm3", config=cfg) + + torch.cuda.empty_cache() From d6e4c24823c97b2cc6132d32700c56c5ef54e27e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 13:05:54 +0100 Subject: [PATCH 08/10] Fix OpenLLaMa default config bug. Skip Mistral test due to lack of GPU memory. 
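Since PATCH 06 the default `torch_dtype` is stored as a dtype name and resolved via
`getattr(torch, ...)` in `HuggingFace.__init__()`; OpenLLaMA's default config still
passed the `torch.float16` object, which the string-based lookup cannot handle. Store
the name "float16" instead and drop the now-unused `torch` import. Roughly, the
resolution step the default has to satisfy (a sketch, not the full init logic):

    import torch

    torch_dtype = "float16"  # OpenLLaMA default after this patch
    if torch_dtype != "auto":
        # getattr() expects a dtype *name*; a torch.dtype object would fail here.
        torch_dtype = getattr(torch, torch_dtype)
    assert torch_dtype is torch.float16

The Mistral init-from-config test is skipped for now because the CI runner does not
have enough GPU memory to load the model.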
--- spacy_llm/models/hf/openllama.py | 4 ++-- spacy_llm/tests/models/test_mistral.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 34248bc4..8ceb5bbc 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -2,7 +2,7 @@ from confection import SimpleFrozenDict -from ...compat import Literal, torch, transformers +from ...compat import Literal, transformers from ...registry.util import registry from .base import HuggingFace @@ -72,7 +72,7 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: return ( { **default_cfg_init, - "torch_dtype": torch.float16, + "torch_dtype": "float16", }, {**default_cfg_run, "max_new_tokens": 32}, ) diff --git a/spacy_llm/tests/models/test_mistral.py b/spacy_llm/tests/models/test_mistral.py index 5dde49f0..548d4d29 100644 --- a/spacy_llm/tests/models/test_mistral.py +++ b/spacy_llm/tests/models/test_mistral.py @@ -48,6 +48,7 @@ def test_init(): @pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") def test_init_from_config(): orig_config = Config().from_str(_NLP_CONFIG) From 565fbefc00484c0c79a9e47daeb41624b21051d3 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 14:52:08 +0100 Subject: [PATCH 09/10] Update spacy_llm/tests/models/test_hf.py Co-authored-by: Sofie Van Landeghem --- spacy_llm/tests/models/test_hf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spacy_llm/tests/models/test_hf.py b/spacy_llm/tests/models/test_hf.py index 306e879a..3058035c 100644 --- a/spacy_llm/tests/models/test_hf.py +++ b/spacy_llm/tests/models/test_hf.py @@ -4,9 +4,7 @@ import spacy from thinc.compat import has_torch_cuda_gpu -from spacy_llm.compat import has_accelerate - -from ...compat import torch +from spacy_llm.compat import has_accelerate, torch _PIPE_CFG = { "model": { From 17a6402c2f2a230a046a95a59637a4e0fd30e363 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 15:00:22 +0100 Subject: [PATCH 10/10] Change device check sequence. --- spacy_llm/models/hf/base.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 8e3e55b6..87209118 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -36,21 +36,20 @@ def __init__( self._config_run = {**self._config_run, **config_run} # `device` and `device_map` are conflicting arguments - ensure they aren't both set. - # Case 1: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user. if config_init: - if "device" in default_cfg_init and "device_map" in config_init: - self._config_init.pop("device") - # Case 2: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user. - if "device_map" in default_cfg_init and "device" in config_init: - self._config_init.pop("device_map") - # Case 3: both explicitly set by user. + # Case 1: both device and device_map explicitly set by user. if "device" in config_init and "device_map" in config_init: warnings.warn( "`device` and `device_map` are conflicting arguments - don't set both. Dropping argument " "`device`." ) - if "device" in self._config_init: - self._config_init.pop("device") + self._config_init.pop("device") + # Case 2: we have a CUDA GPU (and hence device="cuda:0" by default), but device_map is set by user. 
+ elif "device" in default_cfg_init and "device_map" in config_init: + self._config_init.pop("device") + # Case 3: we don't have a CUDA GPU (and hence "device_map=auto" by default), but device is set by user. + elif "device_map" in default_cfg_init and "device" in config_init: + self._config_init.pop("device_map") # Fetch proper torch.dtype, if specified. if (