
Commit

fix bug for layer-wise mode
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Jun 28, 2024
1 parent 200a001 commit 78eb7e0
Showing 3 changed files with 19 additions and 5 deletions.
1 change: 0 additions & 1 deletion auto_round/autoround.py
@@ -459,7 +459,6 @@ def calib(self, n_samples, bs):
n_samples (int): The number of samples to use for calibration.
bs (int): The batch size to use for calibration.
"""

if isinstance(self.dataset, str):
dataset = self.dataset.replace(" ", "") ##remove all whitespaces
# slow here
6 changes: 4 additions & 2 deletions auto_round/layer_wise/utils.py
@@ -285,6 +285,7 @@ def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_pa

def forward_pre_hook(name):
def hook(module, input):
logger.debug(f"{name} forward hood load value")
state_dict = None
if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
@@ -300,6 +301,7 @@ def hook(module, input):

def forward_hook(name):
def hook(module, input, output):
logger.debug(f"{name} forward hood clean value")
if saved_path:
file_path = os.path.join(saved_path, f"{name}.pt")
torch.save(module.state_dict(), file_path)
@@ -368,7 +370,8 @@ def _update(module):

def _layer_wise_to(module, name, device_or_dtype):
if isinstance(device_or_dtype, torch.dtype):
return module.ori_to(device_or_dtype)
module.ori_to(device_or_dtype)
return module
elif len(module._modules) == 0:
# skip method type
if len(module._parameters) == 0:
@@ -420,6 +423,5 @@ def load_model_with_hooks(
if saved_path is None:
saved_path = LWQ_WORKSPACE
empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, **kwargs)
convert_model(empty_model, saved_path=saved_path)
register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path)
return empty_model
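
For context, the hooks registered in `auto_round/layer_wise/utils.py` implement a load-on-demand pattern: a forward pre-hook materializes a module's weights from its per-layer `*.pt` file just before the module runs, and a forward hook saves and offloads them again afterwards. The snippet below is a minimal standalone sketch of that pattern, not the library's implementation; `attach_layer_wise_hooks` is an illustrative name and `load_state_dict(..., assign=True)` assumes PyTorch >= 2.1.

```python
# Sketch of the load-on-demand weight hooks used by layer-wise mode.
import os
import torch


def attach_layer_wise_hooks(model: torch.nn.Module, saved_path: str):
    handles = []

    def make_pre_hook(name):
        def pre_hook(module, inputs):
            file_path = os.path.join(saved_path, f"{name}.pt")
            if os.path.exists(file_path):
                # Materialize real weights in place of the meta parameters.
                state_dict = torch.load(file_path, map_location="cpu")
                module.load_state_dict(state_dict, assign=True)
        return pre_hook

    def make_post_hook(name):
        def post_hook(module, inputs, output):
            # Persist the weights, then free host memory by offloading to meta.
            torch.save(module.state_dict(), os.path.join(saved_path, f"{name}.pt"))
            module.to("meta")
        return post_hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            handles.append(module.register_forward_pre_hook(make_pre_hook(name)))
            handles.append(module.register_forward_hook(make_post_hook(name)))
    return handles
```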
17 changes: 15 additions & 2 deletions auto_round/quantizer.py
@@ -256,10 +256,15 @@ def unwrapper(self, v, min_scale, max_scale):
max_scale,
self.scale_dtype,
)
if self.orig_layer.device.type == 'meta':
self.orig_layer.to('cpu')
self.orig_layer.weight.data.copy_(q_dq_weight)
self.orig_layer.weight.grad = None ##clear grad
self.orig_layer.scale = scale.to("cpu")
self.orig_layer.zp = zp.to("cpu") if zp is not None else None
if hasattr(self.orig_layer, 'update'):
self.orig_layer.update()
self.orig_layer.to('meta')
return self.orig_layer

def forward(self, x):
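
The branch added to `unwrapper` handles wrapped layers whose parameters are still on the meta device: materialize the layer, copy the quantize-dequantize result into it, let the layer-wise `update()` hook persist it, then offload back to meta. Below is a rough standalone sketch of that round trip under stated assumptions: `write_back_quantized` and `save_file` are hypothetical names, and `to_empty` stands in for the hook-patched `to` that the real wrapper relies on.

```python
import torch


def write_back_quantized(layer, q_dq_weight, save_file=None):
    """Sketch: copy a quantize-dequantize weight into a layer that may be offloaded to meta."""
    was_meta = layer.weight.device.type == "meta"
    if was_meta:
        # Allocate real (uninitialized) CPU storage so copy_ has somewhere to write.
        layer.to_empty(device="cpu")
    layer.weight.data.copy_(q_dq_weight)
    layer.weight.grad = None  # clear any stale gradient
    if was_meta:
        if save_file is not None:
            # Persist the updated weights before offloading; the wrapper
            # delegates this step to the layer's update() hook.
            torch.save(layer.state_dict(), save_file)
        layer.to("meta")
    return layer
```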
@@ -274,6 +279,8 @@ def forward(self, x):
from torch.functional import F

weight = self.orig_layer.weight
if weight.device.type == 'meta':
weight = self.orig_layer.get_weight()
self.min_scale.data.copy_(torch.clamp(self.min_scale.data, 0, 1.0))
self.max_scale.data.copy_(torch.clamp(self.max_scale.data, 0, 1.0))
weight_q, _, _ = quant_weight(
Expand All @@ -288,7 +295,8 @@ def forward(self, x):
)
weight_q = weight_q.to(weight.dtype)
# pylint: disable=not-callable
return F.linear(x, weight_q, self.orig_layer.bias)
bias = self.orig_layer.get_bias() if hasattr(self.orig_layer, 'get_bias') else self.orig_layer.bias
return F.linear(x, weight_q, bias)


class WrapperTransformerConv1d(torch.nn.Module):
@@ -323,7 +331,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
weight_dtype = self.orig_layer.weight.dtype
weight_dtype = torch.float32
device = self.orig_layer.weight.device
self.weight_t = self.orig_layer.weight.t()
self.weight_t = self.orig_layer.weight.t() if not hasattr(self.orig_layer, 'get_weight') else self.orig_layer.get_weight().t()
self.value = torch.nn.Parameter(
torch.zeros(self.weight_t.shape, device=device, dtype=weight_dtype), requires_grad=True
)
@@ -356,10 +364,15 @@ def unwrapper(self, v=0, min_scale=1.0, max_scale=1.0):
weight_q, scale, zp = quant_weight(
self.weight_t, self.num_bits, self.group_size, self.sym, v, min_scale, max_scale, self.scale_dtype
)
if self.orig_layer.weight.device.type == 'meta':
self.orig_layer.to('cpu')  # move the whole module; Tensor.to is not in-place
self.orig_layer.weight.data.copy_(weight_q.t())
self.orig_layer.weight.grad = None
self.orig_layer.scale = scale.to("cpu")
self.orig_layer.zp = zp.to("cpu")
if hasattr(self.orig_layer, 'update'):
self.orig_layer.update()
self.orig_layer.to('meta')
return self.orig_layer

def forward(self, x):
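
Both wrappers' `forward` paths now follow the same convention: when the wrapped layer is offloaded to the meta device, real tensors are fetched through the `get_weight`/`get_bias` helpers rather than read from `.weight`/`.bias` directly. A minimal sketch of that lookup for a linear layer, assuming a hooked layer that exposes those helpers (as set up in `auto_round/layer_wise/utils.py`); `meta_aware_linear` is an illustrative name, not part of the library.

```python
import torch
import torch.nn.functional as F


def meta_aware_linear(x: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    # Use materialized copies when the layer is offloaded to meta;
    # get_weight()/get_bias() are assumed to load the real tensors on demand.
    weight = layer.weight
    if weight.device.type == "meta":
        weight = layer.get_weight()
    bias = layer.get_bias() if hasattr(layer, "get_bias") else layer.bias
    return F.linear(x, weight, bias)
```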
