From 063009bc015704aceb818054a3720248a4a0a7c5 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Mon, 3 Jun 2024 14:29:00 +0800
Subject: [PATCH 1/4] set the trainable parameters to 16 bits

---
 auto_round/autoround.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 375da000..6c3d6638 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -70,7 +70,6 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
         self.scale_dtype = self.orig_layer.scale_dtype
         self.sym = self.orig_layer.sym
         weight_dtype = self.orig_layer.weight.dtype
-        weight_dtype = torch.float32 ##TODO revert the change to check the accuracy
         self.value = torch.nn.Parameter(
             torch.zeros(self.orig_layer.weight.shape, device=self.orig_layer.weight.device, dtype=weight_dtype),
             requires_grad=True,
@@ -178,7 +177,6 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
         self.sym = self.orig_layer.sym
         self.scale_dtype = self.orig_layer.scale_dtype
         weight_dtype = self.orig_layer.weight.dtype
-        weight_dtype = torch.float32 ##TODO revert the change to check the accuracy
 
         device = self.orig_layer.weight.device
         self.weight_t = self.orig_layer.weight.t()

From dd40f1750bfe157af918d67b46d4fa2b7222f636 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Mon, 3 Jun 2024 16:12:57 +0800
Subject: [PATCH 2/4] fix bugs

---
 auto_round/auto_quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 84aaa5e7..838d13c0 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -311,7 +311,7 @@ def convert_model(self, model: nn.Module):
         return model
 
     def _dynamic_import_inference_linear(self, bits, backend):
-        if bits == 4 and self.exllama2_available and "exllama2" in backend:
+        if bits == 4 and self.exllama2_available and "exllamav2" in backend:
             from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
         else:
             from auto_round_extension.cuda.qliner_triton import QuantLinear

From f111d05bb148c800470e013f4203a54d5e61ae64 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Mon, 3 Jun 2024 16:16:11 +0800
Subject: [PATCH 3/4] revert the change

---
 auto_round/autoround.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 6c3d6638..127ca52c 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -70,13 +70,13 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
         self.scale_dtype = self.orig_layer.scale_dtype
         self.sym = self.orig_layer.sym
         weight_dtype = self.orig_layer.weight.dtype
+        weight_dtype = torch.float32
         self.value = torch.nn.Parameter(
             torch.zeros(self.orig_layer.weight.shape, device=self.orig_layer.weight.device, dtype=weight_dtype),
             requires_grad=True,
         )
         self.enable_minmax_tuning = enable_minmax_tuning
         shape = get_scale_shape(self.orig_layer.weight, self.group_size)
-        weight_dtype = self.orig_layer.weight.dtype
         if self.enable_minmax_tuning:
             self.min_scale = torch.nn.Parameter(
                 torch.zeros(shape, device=self.orig_layer.weight.device, dtype=weight_dtype), requires_grad=True
@@ -177,7 +177,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
         self.sym = self.orig_layer.sym
         self.scale_dtype = self.orig_layer.scale_dtype
         weight_dtype = self.orig_layer.weight.dtype
-
+        weight_dtype = torch.float32
         device = self.orig_layer.weight.device
         self.weight_t = self.orig_layer.weight.t()
         self.value = torch.nn.Parameter(

From d11b267b3b259475b81b16983aeca24efe4437c6 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Mon, 3 Jun 2024 16:17:00 +0800
Subject: [PATCH 4/4] fix the bug

---
 auto_round/auto_quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 838d13c0..cb7ebee0 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -255,7 +255,7 @@ class AutoRoundQuantizer(HfQuantizer):
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.exllama2_available = is_autoround_exllamav2_available
+        self.exllama2_available = is_autoround_exllamav2_available()
 
     def validate_environment(self, *args, **kwargs):
         if not is_auto_round_available():