From 191a7ddda019aa8a7c37eb0160e3c71dfcdabdc5 Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Thu, 26 Dec 2024 14:42:47 +0800
Subject: [PATCH] Fix torch_dtype in lite (#2956)

---
 lmdeploy/lite/apis/calibrate.py |  6 ++++++
 lmdeploy/lite/utils/load.py     | 19 ++++++++++++++-----
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index 307cf6d7e9..85467997e3 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -262,7 +262,13 @@ def calibrate(model: str,
     if dtype == 'float16':
         model.half()
     elif dtype == 'bfloat16':
+        assert torch.cuda.is_bf16_supported(
+        ), 'your device does not support bfloat16 please set --dtype float16'  # noqa
         model.to(torch.bfloat16)
+    elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
+        print('Warning: we cast model to float16 to prevent OOM. You'
+              ' may enforce it bfloat16 by `--dtype bfloat16`')
+        model.half()
 
     model.eval()
     model_type = type(model).__name__
diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py
index 170c149778..ac4519371a 100644
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
@@ -12,14 +12,13 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
                             dtype: Literal['float16', 'bfloat16', 'auto'],
                             **kwargs):
 
-    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+    if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
         raise RuntimeError('Your device does not supports bf16(bfloat16), '
                            'please change to fp16(float16)')
 
     kwargs.pop('config', None)
 
     hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=dtype,
                                            trust_remote_code=True)
 
     # HACK hard code for qwen, other configs do not have the `fp16` attribute.
@@ -29,13 +28,23 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
         else:
             hf_config.fp16 = True
 
-    if dtype != 'auto':
-        setattr(hf_config, 'torch_dtype', dtype)
+    torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+    if dtype == 'bfloat16':
+        torch_dtype = torch.bfloat16
+    elif dtype == 'float16':
+        torch_dtype = torch.float16
+    elif dtype == 'auto' and torch_dtype == torch.bfloat16:
+        print('Warning: we cast model to float16 to prevent OOM. '
+              'You may enforce it bfloat16 by `--dtype bfloat16`')
+        torch_dtype = torch.float16
 
     with LoadNoInit():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, config=hf_config, **kwargs)
+            pretrained_model_name_or_path,
+            config=hf_config,
+            torch_dtype=torch_dtype,
+            **kwargs)
         model.config.use_cache = False
 
     return model
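
Note (editorial, not part of the applied patch): below is a minimal sketch of the dtype-resolution behaviour this change introduces in load_hf_from_pretrained, shown as a standalone function for readability. The helper name resolve_torch_dtype and its signature are illustrative assumptions, not code from the lmdeploy repository.

# Illustrative sketch only: resolve_torch_dtype is a hypothetical helper that
# mirrors the dtype-selection logic added to load_hf_from_pretrained.
import torch


def resolve_torch_dtype(dtype: str, config_dtype=torch.float16) -> torch.dtype:
    """Pick the torch dtype to load the model with.

    `dtype` is the user-facing option ('float16', 'bfloat16' or 'auto');
    `config_dtype` stands in for `hf_config.torch_dtype`.
    """
    if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
        raise RuntimeError('Your device does not support bf16 (bfloat16), '
                           'please change to fp16 (float16)')
    if dtype == 'bfloat16':
        return torch.bfloat16
    if dtype == 'float16':
        return torch.float16
    # dtype == 'auto': fall back to float16 when the checkpoint is bfloat16,
    # matching the OOM-avoidance warning added by the patch.
    if config_dtype == torch.bfloat16:
        print('Warning: we cast model to float16 to prevent OOM. '
              'You may enforce bfloat16 by `--dtype bfloat16`')
        return torch.float16
    return config_dtype


# Example: a bfloat16 checkpoint loaded with --dtype auto resolves to float16.
print(resolve_torch_dtype('auto', torch.bfloat16))  # torch.float16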