diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py index 88135208..e91be263 100644 --- a/llms/qwen/convert.py +++ b/llms/qwen/convert.py @@ -60,7 +60,7 @@ def convert(args): args.model, trust_remote_code=True, torch_dtype=torch.float16 ) state_dict = model.state_dict() - weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()} config = model.config.to_dict() if args.quantize: