Fix torch_dtype in lite (#2956)
AllentDan authored Dec 26, 2024
1 parent 3a98ae9 commit 191a7dd
Showing 2 changed files with 20 additions and 5 deletions.
6 changes: 6 additions & 0 deletions lmdeploy/lite/apis/calibrate.py
@@ -262,7 +262,13 @@ def calibrate(model: str,
     if dtype == 'float16':
         model.half()
     elif dtype == 'bfloat16':
+        assert torch.cuda.is_bf16_supported(
+        ), 'your device does not support bfloat16 please set --dtype float16'  # noqa
         model.to(torch.bfloat16)
+    elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
+        print('Warning: we cast model to float16 to prevent OOM. You'
+              ' may enforce it bfloat16 by `--dtype bfloat16`')
+        model.half()
     model.eval()

     model_type = type(model).__name__
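
The calibrate.py hunk adds a bf16 capability check and an fp16 fallback for bf16 checkpoints when `--dtype auto` is used. Below is a minimal, self-contained sketch of that rule; the helper name `cast_for_calibration` and the explicit `config_dtype` argument are illustrative stand-ins for the attributes `calibrate()` reads from the HF model, not part of lmdeploy's API.

# Minimal sketch of the casting rule added to calibrate(); names are illustrative.
import torch
import torch.nn as nn


def cast_for_calibration(model: nn.Module, dtype: str,
                         config_dtype: torch.dtype) -> nn.Module:
    if dtype == 'float16':
        model.half()
    elif dtype == 'bfloat16':
        # Fail early on devices without bf16 support.
        assert torch.cuda.is_bf16_supported(), (
            'your device does not support bfloat16, please set --dtype float16')
        model.to(torch.bfloat16)
    elif dtype == 'auto' and config_dtype == torch.bfloat16:
        # Under --dtype auto, bf16 checkpoints are downcast to fp16 to prevent
        # OOM; pass --dtype bfloat16 to keep bf16.
        print('Warning: we cast model to float16 to prevent OOM. You'
              ' may enforce it bfloat16 by `--dtype bfloat16`')
        model.half()
    return model.eval()


# Example with a toy module standing in for the HF model:
toy = cast_for_calibration(nn.Linear(8, 8), 'auto', torch.bfloat16)
print(next(toy.parameters()).dtype)  # torch.float16
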
19 changes: 14 additions & 5 deletions lmdeploy/lite/utils/load.py
@@ -12,14 +12,13 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
                             dtype: Literal['float16', 'bfloat16',
                                            'auto'], **kwargs):
 
-    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+    if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
         raise RuntimeError('Your device does not supports bf16(bfloat16), '
                            'please change to fp16(float16)')
 
     kwargs.pop('config', None)
 
     hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=dtype,
                                            trust_remote_code=True)
 
     # HACK hard code for qwen, other configs do not have the `fp16` attribute.
@@ -29,13 +28,23 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
     else:
         hf_config.fp16 = True
 
-    if dtype != 'auto':
-        setattr(hf_config, 'torch_dtype', dtype)
+    torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+    if dtype == 'bfloat16':
+        torch_dtype = torch.bfloat16
+    elif dtype == 'float16':
+        torch_dtype = torch.float16
+    elif dtype == 'auto' and torch_dtype == torch.bfloat16:
+        print('Warning: we cast model to float16 to prevent OOM. '
+              'You may enforce it bfloat16 by `--dtype bfloat16`')
+        torch_dtype = torch.float16
 
     with LoadNoInit():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, config=hf_config, **kwargs)
+            pretrained_model_name_or_path,
+            config=hf_config,
+            torch_dtype=torch_dtype,
+            **kwargs)
         model.config.use_cache = False
 
     return model
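
The load.py hunk stops passing the dtype string to AutoConfig and instead resolves a concrete torch.dtype that is handed to AutoModelForCausalLM.from_pretrained. A minimal sketch of that resolution; `resolve_torch_dtype` and `config_dtype` are illustrative names, not lmdeploy functions.

# Sketch of the dtype resolution introduced in load_hf_from_pretrained().
import torch


def resolve_torch_dtype(dtype: str, config_dtype=None) -> torch.dtype:
    # Start from the checkpoint's own dtype, defaulting to fp16 when the
    # config does not specify one.
    torch_dtype = config_dtype if config_dtype is not None else torch.float16
    if dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif dtype == 'float16':
        torch_dtype = torch.float16
    elif dtype == 'auto' and torch_dtype == torch.bfloat16:
        # 'auto' on a bf16 checkpoint falls back to fp16 unless bf16 is
        # requested explicitly.
        torch_dtype = torch.float16
    return torch_dtype


assert resolve_torch_dtype('auto', torch.bfloat16) == torch.float16
assert resolve_torch_dtype('bfloat16', torch.float16) == torch.bfloat16
assert resolve_torch_dtype('auto', None) == torch.float16

Resolving the dtype once and passing it to from_pretrained avoids the earlier mismatch where the dtype string was compared against torch.bfloat16 and forwarded to AutoConfig as torch_dtype.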
