From 2eac533b511a39ab2c1cfcdef6162fcf9a9a5764 Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Thu, 31 Oct 2024 11:41:17 +0800 Subject: [PATCH] fix scan issues Signed-off-by: Zhang, Weiwei1 --- auto_round/autoround.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 8409c3a5..05e29f9a 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -281,7 +281,11 @@ def quantize(self): for block_names in all_blocks: inputs = all_inputs[block_names[0]] all_inputs.pop(block_names[0]) - + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith('hidden_state')] + if len(input_id_str) != 1: + raise RuntimeError("hidden_states arg mismatch error") + inputs["input_ids"] = inputs.pop(input_id_str[0], None) clear_memory(self.inputs) if "input_ids" in inputs.keys(): @@ -845,6 +849,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp gradient_accumulate_steps = self.batch_size ##Force to low gpu batch_size = 1 ##Force to low gpu pick_samples = batch_size * gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) if self.sampler != "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] total_loss = 0 @@ -1079,13 +1084,8 @@ def quant_blocks( clear_memory() for n, m in model.named_parameters(): m.requires_grad_(False) - keys = inputs.keys() - input_id_str = [key for key in keys if key.startswith('hidden_state')] - if len(input_id_str) != 1: - raise RuntimeError("hidden_states arg mismatch error") - input_ids = inputs.pop(input_id_str[0], None) - # input_ids = inputs["input_ids"] - # inputs.pop("input_ids", None) + input_ids = inputs["input_ids"] + inputs.pop("input_ids", None) input_others = inputs clear_memory() input_ids = to_device(input_ids, self.cache_device)