Skip to content

Commit

Permalink
Fix (examples): adding bias_quant to final linear layer in resnet18 (#…
Browse files — browse the repository at this point in the history
  • Loading branch information
i-colbert authored Nov 6, 2023
1 parent 513ab4d commit e410ff3
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/brevitas_examples/bnn_pynq/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerChannelFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import IntBias
from brevitas.quant import TruncTo8bit
from brevitas.quant_tensor import QuantTensor

Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
act_bit_width=8,
weight_bit_width=8,
round_average_pool=False,
last_layer_bias_quant=IntBias,
weight_quant=Int8WeightPerChannelFloat,
first_layer_weight_quant=Int8WeightPerChannelFloat,
last_layer_weight_quant=Int8WeightPerTensorFloat):
Expand Down Expand Up @@ -163,6 +165,7 @@ def __init__(
num_classes,
weight_bit_width=8,
bias=True,
bias_quant=last_layer_bias_quant,
weight_quant=last_layer_weight_quant)

for m in self.modules():
Expand Down Expand Up @@ -224,7 +227,8 @@ def quant_resnet18(cfg) -> QuantResNet:
act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
model = QuantResNet(
QuantBasicBlock, [2, 2, 2, 2],
block_impl=QuantBasicBlock,
num_blocks=[2, 2, 2, 2],
num_classes=num_classes,
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width)
Expand Down

0 comments on commit e410ff3

Please sign in to comment.