From 9afd956f521503d1747886c4e2ca5bb1b0bcc41c Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 3 Oct 2024 10:18:44 -0400 Subject: [PATCH] Update number of groups (#178) * set num_groups to 1 if if < 1 * Update src/compressed_tensors/quantization/lifecycle/initialize.py Co-authored-by: Kyle Sayers --------- Co-authored-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5955b5e..49e7b1a 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -172,9 +172,10 @@ def _initialize_scale_zero_point_observer( # (output_channels, 1) expected_shape = (weight_shape[0], 1) elif quantization_args.strategy == QuantizationStrategy.GROUP: + num_groups = weight_shape[1] // quantization_args.group_size expected_shape = ( weight_shape[0], - weight_shape[1] // quantization_args.group_size, + max(num_groups, 1) ) scale_dtype = module.weight.dtype