Skip to content

Commit

Permalink
fix scan issues
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang, Weiwei1 <[email protected]>
  • Loading branch information
WeiweiZhang1 committed Oct 31, 2024
1 parent bcd6138 commit 2eac533
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,11 @@ def quantize(self):
for block_names in all_blocks:
inputs = all_inputs[block_names[0]]
all_inputs.pop(block_names[0])

keys = inputs.keys()
input_id_str = [key for key in keys if key.startswith('hidden_state')]
if len(input_id_str) != 1:
raise RuntimeError("hidden_states arg mismatch error")
inputs["input_ids"] = inputs.pop(input_id_str[0], None)
clear_memory(self.inputs)

if "input_ids" in inputs.keys():
Expand Down Expand Up @@ -845,6 +849,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
gradient_accumulate_steps = self.batch_size ##Force to low gpu
batch_size = 1 ##Force to low gpu
pick_samples = batch_size * gradient_accumulate_steps
pick_samples = min(nsamples, pick_samples)
if self.sampler != "rand":
whole_indices = torch.randperm(nsamples)[:pick_samples]
total_loss = 0
Expand Down Expand Up @@ -1079,13 +1084,8 @@ def quant_blocks(
clear_memory()
for n, m in model.named_parameters():
m.requires_grad_(False)
keys = inputs.keys()
input_id_str = [key for key in keys if key.startswith('hidden_state')]
if len(input_id_str) != 1:
raise RuntimeError("hidden_states arg mismatch error")
input_ids = inputs.pop(input_id_str[0], None)
# input_ids = inputs["input_ids"]
# inputs.pop("input_ids", None)
input_ids = inputs["input_ids"]
inputs.pop("input_ids", None)
input_others = inputs
clear_memory()
input_ids = to_device(input_ids, self.cache_device)
Expand Down

0 comments on commit 2eac533

Please sign in to comment.