From 447abb2b0f4f0d1474652fc39be7a68ed1512ae9 Mon Sep 17 00:00:00 2001 From: Yifan Date: Mon, 25 Dec 2023 22:10:01 +0800 Subject: [PATCH] QWEN: Fix unsupported ScalarType BFloat16 (#187) Fix unsupported ScalarType BFloat16. --- llms/qwen/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py index 881352086..e91be2638 100644 --- a/llms/qwen/convert.py +++ b/llms/qwen/convert.py @@ -60,7 +60,7 @@ def convert(args): args.model, trust_remote_code=True, torch_dtype=torch.float16 ) state_dict = model.state_dict() - weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()} config = model.config.to_dict() if args.quantize: