Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug at whole block is excluded from quantization #156

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,8 +824,6 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, self.cache_device)

if q_input is not None:
for i in range(len(input_ids)):
input_ids[i] = None
input_ids = q_input
torch.cuda.empty_cache()
quantized_layer_names, unquantized_layer_names = wrapper_block(block, self.enable_minmax_tuning)
Expand All @@ -845,6 +843,14 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
else:
optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0)

if len(round_params) + len(minmax_params) <= 0:
dump_info = (
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block"
)
logger.info(dump_info)
return output, output

if self.lr_scheduler is None:
lr_schedule = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters, verbose=False
Expand Down
23 changes: 23 additions & 0 deletions test/test_autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@ def tearDownClass(self):
shutil.rmtree("./saved", ignore_errors=True)
shutil.rmtree("runs", ignore_errors=True)

def test_remove_whole_block(self):
weight_config={"model.decoder.layers.0.self_attn.k_proj":{"data_type":"float"},
"model.decoder.layers.0.self_attn.v_proj": {"data_type": "float"},
"model.decoder.layers.0.self_attn.q_proj": {"data_type": "float"},
"model.decoder.layers.0.self_attn.out_proj": {"data_type": "float"},
"model.decoder.layers.0.fc1": {"data_type": "float"},
"model.decoder.layers.0.fc2": {"data_type": "float"},
}
bits, group_size, sym = 4, 128, False
autoround = AutoRound(
self.model,
self.tokenizer,
bits=bits,
group_size=group_size,
sym=sym,
iters=2,
seqlen=2,
dataset=self.llm_dataloader,
weight_config=weight_config
)
autoround.quantize()


def test_default(self):
bits, group_size, sym = 4, 128, False
autoround = AutoRound(
Expand Down