fix bug where whole block is excluded from quantization (intel#156)
wenhuach21 authored Jun 17, 2024
1 parent 1c3a92d commit a5c322a
Showing 2 changed files with 31 additions and 2 deletions.
10 changes: 8 additions & 2 deletions auto_round/autoround.py
@@ -824,8 +824,6 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
         output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, self.cache_device)
 
         if q_input is not None:
-            for i in range(len(input_ids)):
-                input_ids[i] = None
             input_ids = q_input
             torch.cuda.empty_cache()
         quantized_layer_names, unquantized_layer_names = wrapper_block(block, self.enable_minmax_tuning)
@@ -845,6 +843,14 @@
         else:
             optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0)
 
+        if len(round_params) + len(minmax_params) <= 0:
+            dump_info = (
+                f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
+                f"layers in the block"
+            )
+            logger.info(dump_info)
+            return output, output
+
         if self.lr_scheduler is None:
             lr_schedule = torch.optim.lr_scheduler.LinearLR(
                 optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters, verbose=False
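For context, a standalone distillation of the guard added above — illustrative only, not part of the commit; the helper name maybe_skip_block and its signature are assumptions for this sketch. The idea: when every layer of a block is excluded from quantization (e.g. forced to data_type "float"), wrapper_block leaves no rounding or min/max parameters to tune, so the block's original float output is reused in place of a quantized one.

    # Hypothetical helper mirroring the new early-return guard in quant_block.
    def maybe_skip_block(round_params, minmax_params,
                         quantized_layer_names, unquantized_layer_names,
                         output, logger):
        """Return (output, output) when the block has nothing to tune, else None."""
        if len(round_params) + len(minmax_params) <= 0:
            total = len(quantized_layer_names) + len(unquantized_layer_names)
            # Same message the commit logs, e.g. "quantized 0/6 layers in the block".
            logger.info(f"quantized {len(quantized_layer_names)}/{total} layers in the block")
            # No trainable parameters: the float output doubles as the quantized output.
            return output, output
        return None  # caller continues with optimizer setup and block tuning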
23 changes: 23 additions & 0 deletions test/test_autoround.py
@@ -33,6 +33,29 @@ def tearDownClass(self):
         shutil.rmtree("./saved", ignore_errors=True)
         shutil.rmtree("runs", ignore_errors=True)
 
+    def test_remove_whole_block(self):
+        weight_config={"model.decoder.layers.0.self_attn.k_proj":{"data_type":"float"},
+                       "model.decoder.layers.0.self_attn.v_proj": {"data_type": "float"},
+                       "model.decoder.layers.0.self_attn.q_proj": {"data_type": "float"},
+                       "model.decoder.layers.0.self_attn.out_proj": {"data_type": "float"},
+                       "model.decoder.layers.0.fc1": {"data_type": "float"},
+                       "model.decoder.layers.0.fc2": {"data_type": "float"},
+                       }
+        bits, group_size, sym = 4, 128, False
+        autoround = AutoRound(
+            self.model,
+            self.tokenizer,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            weight_config=weight_config
+        )
+        autoround.quantize()
+
+
     def test_default(self):
         bits, group_size, sym = 4, 128, False
         autoround = AutoRound(
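With the guard in place, the new test should run quantize() to completion even though every layer of decoder block 0 stays in float, and the log for that block should read "quantized 0/6 layers in the block". One way to run just this case, assuming pytest is available in the development environment:

    pytest test/test_autoround.py -k test_remove_whole_block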
