Skip to content

Commit

Permalink
Fix (nn): make bias for QuantLayer optional (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob authored Feb 15, 2024
1 parent 2d76aa3 commit 2369645
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
padding_mode: str = 'zeros',
bias: bool = True,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
padding_mode: str = 'zeros',
bias: bool = True,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_convtranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
dilation: int = 1,
groups: int = 1,
padding_mode: str = 'zeros',
bias: bool = True,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(
dilation: Union[int, Tuple[int]] = 1,
groups: int = 1,
padding_mode: str = 'zeros',
bias: bool = True,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/nn/quant_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def __init__(
hidden_size: int,
num_layers: int = 1,
nonlinearity: str = 'tanh',
bias: bool = True,
bias: Optional[bool] = True,
batch_first: bool = False,
bidirectional: bool = False,
weight_quant=Int8WeightPerTensorFloat,
Expand Down Expand Up @@ -921,7 +921,7 @@ def __init__(
input_size: int,
hidden_size: int,
num_layers: int = 1,
bias: bool = True,
bias: Optional[bool] = True,
batch_first: bool = False,
bidirectional: bool = False,
weight_quant=Int8WeightPerTensorFloat,
Expand Down

0 comments on commit 2369645

Please sign in to comment.