diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index aa1214d5e..62f2051b0 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -238,7 +238,6 @@ def kwargs_prefix(prefix, weight_kwargs):
         'softmax_input_quant': None,
         'attn_output_weights_quant': sym_act_quant,
         'attn_output_weights_bit_width': act_bit_width,
-        'attn_output_weights_signed': not uint_sym_act_for_unsigned_values,
         'q_scaled_quant': sym_act_quant,
         'q_scaled_bit_width': act_bit_width,
         'k_transposed_quant': sym_act_quant,
@@ -275,22 +274,23 @@ def kwargs_prefix(prefix, weight_kwargs):
     act_quant_and_bit_width = {'act_quant': act_quant, 'bit_width': act_bit_width}
     quant_act_kwargs = {**act_quant_and_bit_width, 'return_quant_tensor': True}
 
+    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
+    if uint_sym_act_for_unsigned_values:
+        quant_mha_kwargs['attn_output_weights_signed'] = False
+        unsigned_quant_act_kwargs['signed'] = False
+
     quant_act_map = {
-        torch.nn.ReLU:
-            (qnn.QuantReLU, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),
-        torch.nn.ReLU6:
-            (qnn.QuantReLU, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),
-        torch.nn.Sigmoid: (
-            qnn.QuantSigmoid, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),}
+        torch.nn.ReLU: (qnn.QuantReLU, {
+            **unsigned_quant_act_kwargs}),
+        torch.nn.ReLU6: (qnn.QuantReLU, {
+            **unsigned_quant_act_kwargs}),
+        torch.nn.Sigmoid: (qnn.QuantSigmoid, {
+            **unsigned_quant_act_kwargs}),}
     quant_identity_map = {
         'signed': (qnn.QuantIdentity, {
             **quant_act_kwargs}),
-        'unsigned': (
-            qnn.QuantIdentity, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),}
+        'unsigned': (qnn.QuantIdentity, {
+            **unsigned_quant_act_kwargs}),}
     quant_layerwise_layer_map = {
         torch.nn.Linear: (qnn.QuantLinear, layerwise_quant_wbiol_kwargs),
         torch.nn.MultiheadAttention: (qnn.QuantMultiheadAttention, layerwise_quant_mha_kwargs),
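
The diff hoists the repeated `'signed': not uint_sym_act_for_unsigned_values` entry into a single `unsigned_quant_act_kwargs` dict that is derived once from `quant_act_kwargs` and only overridden when the flag is set. A minimal sketch of that construction pattern, in plain Python with placeholder values rather than the real Brevitas quantizer objects, so the behavioral change is easy to test in isolation:

```python
def build_act_kwargs(uint_sym_act_for_unsigned_values: bool):
    # Stand-ins for the real act_quant / bit_width values in ptq_common.py.
    quant_act_kwargs = {
        'act_quant': 'sym_act_quant_placeholder',
        'bit_width': 8,
        'return_quant_tensor': True}

    # Derive the kwargs for always-non-negative activations (ReLU, ReLU6,
    # Sigmoid, the softmax output in MHA) from the generic ones.
    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
    if uint_sym_act_for_unsigned_values:
        # Only when the flag is on is the quantizer forced to be unsigned.
        unsigned_quant_act_kwargs['signed'] = False
    return quant_act_kwargs, unsigned_quant_act_kwargs


signed_kwargs, unsigned_kwargs = build_act_kwargs(True)
assert unsigned_kwargs['signed'] is False
assert 'signed' not in signed_kwargs   # generic kwargs stay untouched

signed_kwargs, unsigned_kwargs = build_act_kwargs(False)
assert 'signed' not in unsigned_kwargs  # key absent, not signed=True
```

Note the subtle behavioral change visible in the last assertion: before the patch, `signed=True` was passed explicitly whenever `uint_sym_act_for_unsigned_values` was off, whereas after the patch the `signed` key is simply omitted, leaving each quantizer's own default in effect. The same applies to `attn_output_weights_signed`, which is now only set (to `False`) inside the new conditional instead of always being present in `quant_mha_kwargs`.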