diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5955b5e..b217490 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -172,9 +172,10 @@ def _initialize_scale_zero_point_observer( # (output_channels, 1) expected_shape = (weight_shape[0], 1) elif quantization_args.strategy == QuantizationStrategy.GROUP: + num_groups = weight_shape[1] // quantization_args.group_size expected_shape = ( weight_shape[0], - weight_shape[1] // quantization_args.group_size, + 1 if num_groups < 1 else num_groups ) scale_dtype = module.weight.dtype