diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 6c3d6638..127ca52c 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -70,13 +70,13 @@ def __init__(self, orig_layer, enable_minmax_tuning=True): self.scale_dtype = self.orig_layer.scale_dtype self.sym = self.orig_layer.sym weight_dtype = self.orig_layer.weight.dtype + weight_dtype = torch.float32 self.value = torch.nn.Parameter( torch.zeros(self.orig_layer.weight.shape, device=self.orig_layer.weight.device, dtype=weight_dtype), requires_grad=True, ) self.enable_minmax_tuning = enable_minmax_tuning shape = get_scale_shape(self.orig_layer.weight, self.group_size) - weight_dtype = self.orig_layer.weight.dtype if self.enable_minmax_tuning: self.min_scale = torch.nn.Parameter( torch.zeros(shape, device=self.orig_layer.weight.device, dtype=weight_dtype), requires_grad=True @@ -177,7 +177,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True): self.sym = self.orig_layer.sym self.scale_dtype = self.orig_layer.scale_dtype weight_dtype = self.orig_layer.weight.dtype - + weight_dtype = torch.float32 device = self.orig_layer.weight.device self.weight_t = self.orig_layer.weight.t() self.value = torch.nn.Parameter(