diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 846d4f290..5af432af7 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -35,7 +35,7 @@ def __init__( dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, padding_mode: str = 'zeros', - bias: bool = True, + bias: Optional[bool] = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, @@ -129,7 +129,7 @@ def __init__( dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, padding_mode: str = 'zeros', - bias: bool = True, + bias: Optional[bool] = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index c5dbd52b9..64fbe8eb6 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -38,7 +38,7 @@ def __init__( dilation: int = 1, groups: int = 1, padding_mode: str = 'zeros', - bias: bool = True, + bias: Optional[bool] = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, @@ -135,7 +135,7 @@ def __init__( dilation: Union[int, Tuple[int]] = 1, groups: int = 1, padding_mode: str = 'zeros', - bias: bool = True, + bias: Optional[bool] = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, diff --git a/src/brevitas/nn/quant_linear.py b/src/brevitas/nn/quant_linear.py index 576529a32..46f3191b9 100644 --- a/src/brevitas/nn/quant_linear.py +++ b/src/brevitas/nn/quant_linear.py @@ -27,7 +27,7 @@ def __init__( self, in_features: int, out_features: int, - bias: bool, + bias: Optional[bool] = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 396a4f6ef..642c1b2d1 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -882,7 +882,7 @@ def __init__( hidden_size: int, num_layers: int = 1, nonlinearity: str = 'tanh', - bias: bool = True, + bias: Optional[bool] = True, batch_first: bool = False, bidirectional: bool = False, weight_quant=Int8WeightPerTensorFloat, @@ -921,7 +921,7 @@ def __init__( input_size: int, hidden_size: int, num_layers: int = 1, - bias: bool = True, + bias: Optional[bool] = True, batch_first: bool = False, bidirectional: bool = False, weight_quant=Int8WeightPerTensorFloat,