From a02bb2b14fdf696f2c66f417ce43ced70cea1404 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 3 Oct 2024 01:43:42 +0000 Subject: [PATCH] set num_groups to 1 if if < 1 --- src/compressed_tensors/quantization/lifecycle/initialize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5955b5e..b217490 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -172,9 +172,10 @@ def _initialize_scale_zero_point_observer( # (output_channels, 1) expected_shape = (weight_shape[0], 1) elif quantization_args.strategy == QuantizationStrategy.GROUP: + num_groups = weight_shape[1] // quantization_args.group_size expected_shape = ( weight_shape[0], - weight_shape[1] // quantization_args.group_size, + 1 if num_groups < 1 else num_groups ) scale_dtype = module.weight.dtype