
Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead #876

Open · agunapal opened this issue Sep 11, 2024 · 1 comment

@agunapal

Getting this error with int4 quantization.

Maybe a noob question: is this a bug, or does int4 require the weights to be in bfloat16?

```
Traceback (most recent call last):
  File "/home/agunapal/torch_ao/vit_ao.py", line 16, in <module>
    quantize_(model, int4_weight_only())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 463, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 199, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 393, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 319, in pack_tinygemm_scales_and_zeros
    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 128, in guard_dtype_size
    raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
ValueError: Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead.
```

Code for repro:

```python
import torch
import torchao

from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.utils import benchmark_model
from torchao.quantization import int4_weight_only, quantize_

torch.set_float32_matmul_precision('high')

dtype = torch.float32
device = "cuda"
N = 1

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
quantize_(model, int4_weight_only())  # raises the ValueError above
model = torch.compile(model, mode='max-autotune').to(device).to(dtype)
method = "int4 quantize followed by compile"
input = (torch.randn(N, 3, 224, 224).to(device).to(dtype),)

result = []
with torch.no_grad():
    # warmup
    benchmark_model(model, 20, input)
    # benchmark
    result.append((method, N, benchmark_model(model, 100, input)))

for (method, N, elapsed_time) in result:
    print(f"batch_size={N} : elapsed time {elapsed_time:.3f} ms : {method}")
```

@jerryzh168 (Contributor)

Yeah, int4_weight_only quant requires bfloat16 right now, I think, since that's the only dtype the tinygemm kernel supports (int4_weight_only actually corresponds to just the int4 tinygemm kernel; it's not a general int4 weight-only quant).
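A minimal sketch of the workaround this implies: cast the model to bfloat16 before calling quantize_. The exact placement of the casts is my assumption, not something stated in the thread.

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.quantization import int4_weight_only, quantize_

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

# Cast to bfloat16 *before* quantizing: the tinygemm int4 kernel backing
# int4_weight_only only supports bfloat16, so the scales must be bfloat16 too.
model = model.to(device="cuda", dtype=torch.bfloat16)
quantize_(model, int4_weight_only())
```

With the model in bfloat16, the input tensor in the repro would also need to be bfloat16 rather than float32.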
