From 78eb7e01d04899ae8e5f60a8403eeb23105aec32 Mon Sep 17 00:00:00 2001
From: n1ck-guo
Date: Thu, 27 Jun 2024 21:07:07 -0400
Subject: [PATCH] fix bug for layer-wise mode

Signed-off-by: n1ck-guo
---
 auto_round/autoround.py        |  1 -
 auto_round/layer_wise/utils.py |  6 ++++--
 auto_round/quantizer.py        | 17 +++++++++++++++--
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index c79358b9..0e950774 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -459,7 +459,6 @@ def calib(self, n_samples, bs):
             n_samples (int): The number of samples to use for calibration.
             bs (int): The number of samples to use for calibration
         """
-
         if isinstance(self.dataset, str):
             dataset = self.dataset.replace(" ", "")  ##remove all whitespaces
             # slow here
diff --git a/auto_round/layer_wise/utils.py b/auto_round/layer_wise/utils.py
index 0571d8b1..f909ce1e 100644
--- a/auto_round/layer_wise/utils.py
+++ b/auto_round/layer_wise/utils.py
@@ -285,6 +285,7 @@ def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_pa

     def forward_pre_hook(name):
         def hook(module, input):
+            logger.debug(f"{name} forward hook load value")
             state_dict = None
             if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
                 state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
@@ -300,6 +301,7 @@ def hook(module, input):

     def forward_hook(name):
         def hook(module, input, output):
+            logger.debug(f"{name} forward hook clean value")
             if saved_path:
                 file_path = os.path.join(saved_path, f"{name}.pt")
                 torch.save(module.state_dict(), file_path)
@@ -368,7 +370,8 @@ def _update(module):

 def _layer_wise_to(module, name, device_or_dtype):
     if isinstance(device_or_dtype, torch.dtype):
-        return module.ori_to(device_or_dtype)
+        module.ori_to(device_or_dtype)
+        return module
     elif len(module._modules) == 0:
         # skip method type
         if len(module._parameters) == 0:
@@ -420,6 +423,5 @@ def load_model_with_hooks(
     if saved_path is None:
         saved_path = LWQ_WORKSPACE
     empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, **kwargs)
-    convert_model(empty_model, saved_path=saved_path)
     register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path)
     return empty_model
\ No newline at end of file
diff --git a/auto_round/quantizer.py b/auto_round/quantizer.py
index 5fce47d4..5b750eb8 100644
--- a/auto_round/quantizer.py
+++ b/auto_round/quantizer.py
@@ -256,10 +256,15 @@ def unwrapper(self, v, min_scale, max_scale):
             max_scale,
             self.scale_dtype,
         )
+        if self.orig_layer.device.type == 'meta':
+            self.orig_layer.to('cpu')
         self.orig_layer.weight.data.copy_(q_dq_weight)
         self.orig_layer.weight.grad = None  ##clear grad
         self.orig_layer.scale = scale.to("cpu")
         self.orig_layer.zp = zp.to("cpu") if zp is not None else None
+        if hasattr(self.orig_layer, 'update'):
+            self.orig_layer.update()
+            self.orig_layer.to('meta')
         return self.orig_layer

     def forward(self, x):
@@ -274,6 +279,8 @@ def forward(self, x):
         from torch.functional import F

         weight = self.orig_layer.weight
+        if weight.device.type == 'meta':
+            weight = self.orig_layer.get_weight()
         self.min_scale.data.copy_(torch.clamp(self.min_scale.data, 0, 1.0))
         self.max_scale.data.copy_(torch.clamp(self.max_scale.data, 0, 1.0))
         weight_q, _, _ = quant_weight(
@@ -288,7 +295,8 @@ def forward(self, x):
         )
         weight_q = weight_q.to(weight.dtype)
         # pylint: disable=not-callable
-        return F.linear(x, weight_q, self.orig_layer.bias)
+        bias = self.orig_layer.get_bias() if hasattr(self.orig_layer, 'get_bias') else self.orig_layer.bias
+        return F.linear(x, weight_q, bias)


 class WrapperTransformerConv1d(torch.nn.Module):
@@ -323,7 +331,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
         weight_dtype = self.orig_layer.weight.dtype
         weight_dtype = torch.float32
         device = self.orig_layer.weight.device
-        self.weight_t = self.orig_layer.weight.t()
+        self.weight_t = self.orig_layer.weight.t() if not hasattr(self.orig_layer, 'get_weight') else self.orig_layer.get_weight().t()
         self.value = torch.nn.Parameter(
             torch.zeros(self.weight_t.shape, device=device, dtype=weight_dtype), requires_grad=True
         )
@@ -356,10 +364,15 @@ def unwrapper(self, v=0, min_scale=1.0, max_scale=1.0):
         weight_q, scale, zp = quant_weight(
             self.weight_t, self.num_bits, self.group_size, self.sym, v, min_scale, max_scale, self.scale_dtype
         )
+        if self.orig_layer.weight.device.type == 'meta':
+            self.orig_layer.to('cpu')
         self.orig_layer.weight.data.copy_(weight_q.t())
         self.orig_layer.weight.grad = None
         self.orig_layer.scale = scale.to("cpu")
         self.orig_layer.zp = zp.to("cpu")
+        if hasattr(self.orig_layer, 'update'):
+            self.orig_layer.update()
+            self.orig_layer.to('meta')
         return self.orig_layer

     def forward(self, x):
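
The quantizer.py hunks above all apply the same layer-wise pattern: when the wrapped layer's weight currently lives on the 'meta' device, fetch a materialized copy through a get_weight()/get_bias()-style helper (or move the module to CPU) before quantizing, then push the layer back to 'meta' once the quantized values have been written out. The sketch below reproduces only that device check in isolation, under stated assumptions: LazyLinear and quantize_like_wrapper are hypothetical stand-ins written for illustration, not auto_round APIs, and the round-to-nearest step is a placeholder for quant_weight().

import torch


class LazyLinear(torch.nn.Linear):
    """Hypothetical stand-in for a layer-wise layer: parameters are kept on the
    'meta' device and the real values are materialized on demand."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        self._cpu_weight = self.weight.data.clone()  # pretend this is the offloaded copy
        self.to("meta")  # keep only shapes/dtypes in memory

    def get_weight(self):
        # mirrors the get_weight()-style accessor the patch relies on
        return self._cpu_weight


def quantize_like_wrapper(layer):
    weight = layer.weight
    if weight.device.type == "meta":  # the same check the patch adds in WrapperLinear.forward
        weight = layer.get_weight()
    # illustration-only stand-in for quant_weight(): per-tensor 4-bit round-to-nearest
    scale = weight.abs().max() / 7
    q_dq_weight = torch.clamp(torch.round(weight / scale), -8, 7) * scale
    return q_dq_weight, scale


layer = LazyLinear(8, 4)
q_dq_weight, scale = quantize_like_wrapper(layer)
print(q_dq_weight.shape, scale.item())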