Skip to content

Commit

Permalink
revert the change
Browse files Browse the repository at this point in the history
  • Loading branch information
wenhuach21 committed Jun 3, 2024
1 parent dd40f17 commit f111d05
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
self.scale_dtype = self.orig_layer.scale_dtype
self.sym = self.orig_layer.sym
weight_dtype = self.orig_layer.weight.dtype
weight_dtype = torch.float32
self.value = torch.nn.Parameter(
torch.zeros(self.orig_layer.weight.shape, device=self.orig_layer.weight.device, dtype=weight_dtype),
requires_grad=True,
)
self.enable_minmax_tuning = enable_minmax_tuning
shape = get_scale_shape(self.orig_layer.weight, self.group_size)
weight_dtype = self.orig_layer.weight.dtype
if self.enable_minmax_tuning:
self.min_scale = torch.nn.Parameter(
torch.zeros(shape, device=self.orig_layer.weight.device, dtype=weight_dtype), requires_grad=True
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
self.sym = self.orig_layer.sym
self.scale_dtype = self.orig_layer.scale_dtype
weight_dtype = self.orig_layer.weight.dtype

weight_dtype = torch.float32
device = self.orig_layer.weight.device
self.weight_t = self.orig_layer.weight.t()
self.value = torch.nn.Parameter(
Expand Down

0 comments on commit f111d05

Please sign in to comment.