Skip to content

Commit

Permalink
refine code
Browse files (browse the repository at this point in the history)
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 committed Jun 26, 2024
1 parent 48244dc commit f50db32
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions auto_round/autoround.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -832,14 +832,14 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
Returns:
Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output)
"""
logger.info(f"Start to quant block with algo: {algo}")
logger.info(f">>>>> Start to quant block with algo: {algo}")

output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, self.cache_device)

if q_input is not None:
input_ids = q_input
torch.cuda.empty_cache()
# quantized_layer_names, unquantized_layer_names = wrapper_block(block, self.enable_minmax_tuning, self.enable_teq, algo=algo)

quantized_layer_names, unquantized_layer_names = quantizer.wrapper_block_entry(block, self.enable_minmax_tuning, algo=algo)


Expand All @@ -863,7 +863,8 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
from auto_round import teq
teq_params_lst = teq.get_scale_param_from_block(block)
trainable_params.append({"params": teq_params_lst})

else:
raise NotImplementedError(f"Unsupported algo: {algo}")

optimizer = self.optimizer(params=trainable_params, lr=self.lr, weight_decay=0)

Expand Down Expand Up @@ -1054,6 +1055,8 @@ def quant_blocks(
m = m.to(device)

algo_lst = [utils.AlgoEnum.Rounding, utils.AlgoEnum.TEQ] if self.enable_teq else [utils.AlgoEnum.Rounding]
logger.info(f"Apply quantization for block {n} with algos: {algo_lst}")

for algo in algo_lst:
q_input, input_ids = self.quant_block(
m,
Expand Down

0 comments on commit f50db32

Please sign in to comment.