Commit d21ae11

Fix
Giuseppe5 committed Oct 2, 2023
1 parent 5f2790a · commit d21ae11
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -238,7 +238,6 @@ def kwargs_prefix(prefix, weight_kwargs):
         'softmax_input_quant': None,
         'attn_output_weights_quant': sym_act_quant,
         'attn_output_weights_bit_width': act_bit_width,
-        'attn_output_weights_signed': not uint_sym_act_for_unsigned_values,
         'q_scaled_quant': sym_act_quant,
         'q_scaled_bit_width': act_bit_width,
         'k_transposed_quant': sym_act_quant,
@@ -275,22 +274,23 @@ def kwargs_prefix(prefix, weight_kwargs):
 
     act_quant_and_bit_width = {'act_quant': act_quant, 'bit_width': act_bit_width}
     quant_act_kwargs = {**act_quant_and_bit_width, 'return_quant_tensor': True}
+    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
+    if uint_sym_act_for_unsigned_values:
+        quant_mha_kwargs['attn_output_weights_signed'] = False
+        unsigned_quant_act_kwargs['signed'] = False
 
     quant_act_map = {
-        torch.nn.ReLU:
-            (qnn.QuantReLU, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),
-        torch.nn.ReLU6:
-            (qnn.QuantReLU, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),
-        torch.nn.Sigmoid: (
-            qnn.QuantSigmoid, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),}
+        torch.nn.ReLU: (qnn.QuantReLU, {
+            **unsigned_quant_act_kwargs}),
+        torch.nn.ReLU6: (qnn.QuantReLU, {
+            **unsigned_quant_act_kwargs}),
+        torch.nn.Sigmoid: (qnn.QuantSigmoid, {
+            **unsigned_quant_act_kwargs}),}
     quant_identity_map = {
         'signed': (qnn.QuantIdentity, {
             **quant_act_kwargs}),
-        'unsigned': (
-            qnn.QuantIdentity, {
-                **quant_act_kwargs, 'signed': not uint_sym_act_for_unsigned_values}),}
+        'unsigned': (qnn.QuantIdentity, {
+            **unsigned_quant_act_kwargs}),}
     quant_layerwise_layer_map = {
         torch.nn.Linear: (qnn.QuantLinear, layerwise_quant_wbiol_kwargs),
         torch.nn.MultiheadAttention: (qnn.QuantMultiheadAttention, layerwise_quant_mha_kwargs),
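
Net effect of the change: rather than repeating 'signed': not uint_sym_act_for_unsigned_values at every unsigned-activation entry, the flag is flipped once on a copied kwargs dict (unsigned_quant_act_kwargs) that the ReLU, ReLU6, Sigmoid, and unsigned QuantIdentity entries all share, and the 'attn_output_weights_signed' override moves into that same conditional instead of sitting unconditionally in quant_mha_kwargs. Below is a minimal standalone sketch of this copy-then-override pattern; build_act_kwargs is a hypothetical helper using plain dicts and a placeholder quantizer name, not the Brevitas API itself:

def build_act_kwargs(act_quant, act_bit_width, uint_sym_act_for_unsigned_values):
    # Kwargs shared by every quantized activation.
    quant_act_kwargs = {
        'act_quant': act_quant,
        'bit_width': act_bit_width,
        'return_quant_tensor': True}
    # Copy once and flip 'signed' in a single place; entries for
    # non-negative activations (ReLU, ReLU6, Sigmoid) reuse the copy.
    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
    if uint_sym_act_for_unsigned_values:
        unsigned_quant_act_kwargs['signed'] = False
    return quant_act_kwargs, unsigned_quant_act_kwargs

signed_kwargs, unsigned_kwargs = build_act_kwargs('sym_act_quant', 8, True)
assert unsigned_kwargs['signed'] is False
assert 'signed' not in signed_kwargs  # dict.copy() leaves the original untouched

One design note: dict.copy() is shallow, which is safe here because the override replaces a top-level scalar value rather than mutating a nested object.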
