diff --git a/docs/.buildinfo b/docs/.buildinfo index 028bec665..ae0de1084 100644 --- a/docs/.buildinfo +++ b/docs/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: d9a08019dec6882195fa0ef4f685a5cd +config: 029b3c75722303c51813ef033f4d07df tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/_modules/brevitas/core/bit_width/const.html b/docs/_modules/brevitas/core/bit_width/const.html index 63cecaf10..476d7ce41 100644 --- a/docs/_modules/brevitas/core/bit_width/const.html +++ b/docs/_modules/brevitas/core/bit_width/const.html @@ -8,7 +8,7 @@ - brevitas.core.bit_width.const — Brevitas 0.10.2 documentation + brevitas.core.bit_width.const — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + diff --git a/docs/_modules/brevitas/core/bit_width/parameter.html b/docs/_modules/brevitas/core/bit_width/parameter.html index 7b8942cba..d0e990a63 100644 --- a/docs/_modules/brevitas/core/bit_width/parameter.html +++ b/docs/_modules/brevitas/core/bit_width/parameter.html @@ -8,7 +8,7 @@ - brevitas.core.bit_width.parameter — Brevitas 0.10.2 documentation + brevitas.core.bit_width.parameter — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + diff --git a/docs/_modules/brevitas/core/function_wrapper/clamp.html b/docs/_modules/brevitas/core/function_wrapper/clamp.html index 76a755ef8..6b6b1c372 100644 --- a/docs/_modules/brevitas/core/function_wrapper/clamp.html +++ b/docs/_modules/brevitas/core/function_wrapper/clamp.html @@ -8,7 +8,7 @@ - brevitas.core.function_wrapper.clamp — Brevitas 0.10.2 documentation + brevitas.core.function_wrapper.clamp — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + @@ -414,12 +414,16 @@

Source code for brevitas.core.function_wrapper.clamp

""" ScriptModule wrappers for various variants of clamping. """ +from typing import Optional, Tuple import torch from torch import Tensor +from torch.nn import Module import brevitas +from brevitas.core.utils import StatelessBuffer from brevitas.function import tensor_clamp +from brevitas.function.ops import max_float
[docs]class TensorClamp(brevitas.jit.ScriptModule): @@ -483,6 +487,90 @@

Source code for brevitas.core.function_wrapper.clamp

[docs] @brevitas.jit.script_method def forward(self, x: Tensor): return x.clamp_min(self.min_val)
+ + +
[docs]class FloatClamp(brevitas.jit.ScriptModule): + """ + ScriptModule for clamping minifloat formats to their inf/NaN implementations. + + Currently, inf/NaN codes have to be encoded through the mantissa. + For example, encoding inf as 1101.111 in E4M3 is not a valid code. + """ + + __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed'] + + def __init__( + self, + tensor_clamp_impl: Module, + signed: bool, + inf_values: Optional[Tuple[str]] = None, + nan_values: Optional[Tuple[str]] = None, + max_available_float: Optional[Tensor] = None, + saturating: bool = True, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None) -> None: + super(FloatClamp, self).__init__() + + self.tensor_clamp_impl = tensor_clamp_impl + self.saturating = saturating + self.inf_values = inf_values + self.nan_values = nan_values + self.signed = signed + + if max_available_float: + max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype) + self.max_available_float = StatelessBuffer(max_available_float) + else: + self.max_available_float = None +
[docs] def inf_nan_clamp(self, x, inf_mask, p_max_val_mask, n_max_val_mask): + + # if non-saturating, we need to map values greater than max_val to nan or inf + if self.inf_values is not None: + # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf + x[p_max_val_mask] = torch.tensor(float('inf')) + x[n_max_val_mask] = torch.tensor(float('-inf')) + elif self.nan_values is not None: + # no inf values, so we need to map them to NaN + full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) + x[full_max_val_mask] = torch.tensor(float('nan')) + + # we also map the inf values to NaN in this case + x[inf_mask] = torch.tensor(float('nan')) + else: + raise RuntimeError( + "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" + ) + return x
+ +
[docs] def saturating_clamp(self, x, max_value, min_value): + return self.tensor_clamp_impl(x, min_val=min_value, max_val=max_value)
+ +
[docs] @brevitas.jit.script_method + def forward( + self, + x: Tensor, + exponent_bit_width: Tensor, + mantissa_bit_width: Tensor, + exponent_bias: Tensor): + + max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) + max_value = max_value if self.max_available_float is None else torch.min( + max_value, self.max_available_float()) + min_value = torch.tensor(0.) if not self.signed else -max_value + + # Compute masks + inf_mask = x.isinf() + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value + + # first clamp everything to +- max_value, basically the saturating case + x = self.saturating_clamp(x, max_value, min_value) + + if not self.saturating: + x = self.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) + + return x, self.saturating, self.inf_values, self.nan_values
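For reference, the clamping semantics above reduce to the following standalone sketch, assuming the OCP E4M3 maximum of 448 (the helper name and hard-coded constant are illustrative, not Brevitas API):

    import torch

    def e4m3_clamp(x: torch.Tensor, saturating: bool = True, max_value: float = 448.) -> torch.Tensor:
        # E4M3 encodes NaN but no inf, so non-saturating overflow maps to NaN
        inf_mask = x.isinf()
        p_mask = x > max_value
        n_mask = -x > max_value
        x = x.clamp(-max_value, max_value)  # the saturating case
        if not saturating:
            x[p_mask | n_mask | inf_mask] = float('nan')
        return x

    x = torch.tensor([500., -500., float('inf'), 3.25])
    print(e4m3_clamp(x))                    # -> [448., -448., 448., 3.25]
    print(e4m3_clamp(x, saturating=False))  # -> [nan, nan, nan, 3.25]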
diff --git a/docs/_modules/brevitas/core/function_wrapper/misc.html b/docs/_modules/brevitas/core/function_wrapper/misc.html index 081a0b671..cdc8f4546 100644 --- a/docs/_modules/brevitas/core/function_wrapper/misc.html +++ b/docs/_modules/brevitas/core/function_wrapper/misc.html @@ -8,7 +8,7 @@ - brevitas.core.function_wrapper.misc — Brevitas 0.10.2 documentation + brevitas.core.function_wrapper.misc — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + diff --git a/docs/_modules/brevitas/core/function_wrapper/ops_ste.html b/docs/_modules/brevitas/core/function_wrapper/ops_ste.html index cd18f66aa..7a8cab9da 100644 --- a/docs/_modules/brevitas/core/function_wrapper/ops_ste.html +++ b/docs/_modules/brevitas/core/function_wrapper/ops_ste.html @@ -8,7 +8,7 @@ - brevitas.core.function_wrapper.ops_ste — Brevitas 0.10.2 documentation + brevitas.core.function_wrapper.ops_ste — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + diff --git a/docs/_modules/brevitas/core/function_wrapper/shape.html b/docs/_modules/brevitas/core/function_wrapper/shape.html index d4ca6da1c..37f7c4f83 100644 --- a/docs/_modules/brevitas/core/function_wrapper/shape.html +++ b/docs/_modules/brevitas/core/function_wrapper/shape.html @@ -8,7 +8,7 @@ - brevitas.core.function_wrapper.shape — Brevitas 0.10.2 documentation + brevitas.core.function_wrapper.shape — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + @@ -426,6 +426,7 @@

Source code for brevitas.core.function_wrapper.shape

from brevitas.function.shape import over_output_channels from brevitas.function.shape import over_output_features from brevitas.function.shape import over_tensor +from brevitas.utils.torch_utils import padding
[docs]class PermuteDims(brevitas.jit.ScriptModule): @@ -563,6 +564,54 @@

Source code for brevitas.core.function_wrapper.shape

return y.reshape(shape)
+
[docs]class OverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['expanded_groupwise_shape', 'group_size', 'group_dim'] + + def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None: + super(OverSubChannelBlockView, self).__init__() + self.expanded_groupwise_shape = expanded_groupwise_shape + self.group_dim = group_dim + self.group_size = group_size + +
[docs] @brevitas.jit.script_method + def forward(self, x: torch.Tensor): + # Two cases can reach this point with an input that is already expanded: + # - quantizing the zero point, whose shape already matches the scale (unpadded, but no padding is needed) + # - groupwise HQO quantization, where the weight has already been padded and expanded + if len(x.shape) == len(self.expanded_groupwise_shape): + return x + y = torch.nn.functional.pad( + x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.) + y = y.view(self.expanded_groupwise_shape) + return y
+ + +
[docs]class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['group_size', 'group_dim'] + + def __init__(self, group_size, group_dim) -> None: + super(DynamicOverSubChannelBlockView, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + +
[docs] @brevitas.jit.script_method + def forward(self, x): + + pad = padding(x, self.group_size, self.group_dim) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0.) + + tensor_shape_list = list(x.shape) + tensor_shape_list[self.group_dim] = tensor_shape_list[self.group_dim] // self.group_size + block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list.insert(block_dim, self.group_size) + x = x.view(tensor_shape_list) + return x
+ +
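Both views reduce to the same pad-then-reshape idiom; a self-contained sketch (assuming a non-negative group_dim; the inline padding computation is illustrative, mirroring brevitas.utils.torch_utils.padding):

    import torch

    def group_view(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
        # Pad the grouped dim up to a multiple of group_size, then split it
        # into (num_groups, group_size).
        pad_len = -x.shape[group_dim] % group_size
        # F.pad takes (left, right) pairs starting from the last dimension
        pad = [0, 0] * (x.dim() - 1 - group_dim) + [0, pad_len]
        x = torch.nn.functional.pad(x, pad, mode='constant', value=0.)
        shape = list(x.shape)
        shape[group_dim] //= group_size
        shape.insert(group_dim + 1, group_size)
        return x.view(shape)

    w = torch.randn(8, 10)
    print(group_view(w, group_size=4, group_dim=1).shape)  # torch.Size([8, 3, 4])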
[docs]class StatsInputViewShapeImpl(object): """ Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. @@ -572,7 +621,9 @@

Source code for brevitas.core.function_wrapper.shape

OVER_OUTPUT_CHANNELS = OverOutputChannelView OVER_BATCH_OVER_TENSOR = OverBatchOverTensorView OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView - OVER_OUTPUT_FEATURES = OverOutputFeaturesView
+ OVER_OUTPUT_FEATURES = OverOutputFeaturesView + OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView + DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView
diff --git a/docs/_modules/brevitas/core/quant/binary.html b/docs/_modules/brevitas/core/quant/binary.html index 0c1584966..5ba1e04b4 100644 --- a/docs/_modules/brevitas/core/quant/binary.html +++ b/docs/_modules/brevitas/core/quant/binary.html @@ -8,7 +8,7 @@ - brevitas.core.quant.binary — Brevitas 0.10.2 documentation + brevitas.core.quant.binary — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + @@ -458,8 +458,9 @@

Source code for brevitas.core.quant.binary

         Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
     """
 
-    def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
+    def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
         super(BinaryQuant, self).__init__()
+        assert signed, "Unsigned binary quant not supported"
         self.scaling_impl = scaling_impl
         self.bit_width = BitWidthConst(1)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
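The resulting quantizer maps every element to {-scale, +scale}; a one-line sketch of the signed case (plain torch.sign here, where Brevitas routes through a straight-through binary sign):

    import torch

    scale = torch.tensor(0.1)
    x = torch.tensor([-0.3, 0.02, 0.7])
    y = torch.sign(x) * scale  # -> [-0.1, 0.1, 0.1]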
diff --git a/docs/_modules/brevitas/core/quant/delay.html b/docs/_modules/brevitas/core/quant/delay.html
index 347941eea..a401098b0 100644
--- a/docs/_modules/brevitas/core/quant/delay.html
+++ b/docs/_modules/brevitas/core/quant/delay.html
@@ -8,7 +8,7 @@
   
     
     
-    brevitas.core.quant.delay — Brevitas 0.10.2 documentation
+    brevitas.core.quant.delay — Brevitas 0.11.0 documentation
   
   
   
@@ -123,8 +123,8 @@
       
     
     
-    Brevitas 0.10.2 documentation - Home
-    
+    Brevitas 0.11.0 documentation - Home
+    
   
   
 
diff --git a/docs/_modules/brevitas/core/quant/int.html b/docs/_modules/brevitas/core/quant/int.html index 1eb00d929..193997759 100644 --- a/docs/_modules/brevitas/core/quant/int.html +++ b/docs/_modules/brevitas/core/quant/int.html @@ -8,7 +8,7 @@ - brevitas.core.quant.int — Brevitas 0.10.2 documentation + brevitas.core.quant.int — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + @@ -555,15 +555,18 @@

Source code for brevitas.core.quant.int

         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
 
[docs] @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: bit_width = self.msb_clamp_bit_width_impl() - threshold = self.scaling_impl(x) int_threshold = self.int_scaling_impl(bit_width) - scale = threshold / int_threshold + scale = self.scaling_impl(x, int_threshold) zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.int_quant(scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width
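As a rough sketch of the semantics above (stand-in names, signed symmetric quantization for brevity): the scaling implementation now receives the integer threshold and combines it with the statistics internally, while observer_only returns the input unquantized but still produces scale and zero-point:

    import torch

    def rescaling_int_quant(x, abs_max, bit_width=8, observer_only=False):
        int_threshold = 2. ** (bit_width - 1) - 1  # 127 for signed 8-bit
        scale = abs_max / int_threshold            # threshold combined inside the scaling impl
        zero_point = torch.tensor(0.)
        if observer_only:
            y = x                                  # statistics observed, tensor passed through
        else:
            y = torch.round(x / scale).clamp(-int_threshold - 1, int_threshold) * scale
        return y, scale, zero_point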
@@ -586,6 +589,7 @@

Source code for brevitas.core.quant.int

         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
 
[docs] @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -594,10 +598,12 @@

Source code for brevitas.core.quant.int

         pre_threshold = self.pre_scaling_impl(x)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
@@ -660,10 +666,12 @@

Source code for brevitas.core.quant.int

         pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
diff --git a/docs/_modules/brevitas/core/quant/int_base.html b/docs/_modules/brevitas/core/quant/int_base.html index d4223cf8b..8c84f3583 100644 --- a/docs/_modules/brevitas/core/quant/int_base.html +++ b/docs/_modules/brevitas/core/quant/int_base.html @@ -8,7 +8,7 @@ - brevitas.core.quant.int_base — Brevitas 0.10.2 documentation + brevitas.core.quant.int_base — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home + @@ -461,6 +461,7 @@

Source code for brevitas.core.quant.int_base

self,
             narrow_range: bool,
             signed: bool,
+            input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
             tensor_clamp_impl: Module = TensorClamp(),
             quant_delay_steps: int = 0):
@@ -470,9 +471,11 @@ 

Source code for brevitas.core.quant.int_base

self.signed = signed
         self.narrow_range = narrow_range
         self.delay_wrapper = DelayWrapper(quant_delay_steps)
+        self.input_view_impl = input_view_impl
 
 
[docs] @brevitas.jit.script_method def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / scale y = y + zero_point min_int_val = self.min_int(bit_width) @@ -534,6 +537,7 @@

Source code for brevitas.core.quant.int_base

self,
             narrow_range: bool,
             signed: bool,
+            input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
             tensor_clamp_impl: Module = TensorClamp(),
             quant_delay_steps: int = 0):
@@ -543,11 +547,13 @@ 

Source code for brevitas.core.quant.int_base

self.signed = signed
         self.narrow_range = narrow_range
         self.delay_wrapper = DelayWrapper(quant_delay_steps)
+        self.input_view_impl = input_view_impl
 
 
[docs] @brevitas.jit.script_method def to_int( self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / pre_scale y = y + pre_zero_point min_int_val = self.min_int(bit_width) diff --git a/docs/_modules/brevitas/core/quant/ternary.html b/docs/_modules/brevitas/core/quant/ternary.html index 9ac3f7859..b77615c2c 100644 --- a/docs/_modules/brevitas/core/quant/ternary.html +++ b/docs/_modules/brevitas/core/quant/ternary.html @@ -8,7 +8,7 @@ - brevitas.core.quant.ternary — Brevitas 0.10.2 documentation + brevitas.core.quant.ternary — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
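Ignoring the new input view and the STE wrappers, to_int is the usual affine conversion; a small worked example of the arithmetic:

    import torch

    x = torch.tensor([0.26, -1.3, 3.0])
    scale, zero_point = torch.tensor(0.1), torch.tensor(0.)
    min_int, max_int = -128., 127.  # signed 8-bit, narrow_range=False
    y = torch.clamp(torch.round(x / scale + zero_point), min_int, max_int)
    # -> [3., -13., 30.]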
diff --git a/docs/_modules/brevitas/core/restrict_val.html b/docs/_modules/brevitas/core/restrict_val.html index cdf92c705..724e23eaf 100644 --- a/docs/_modules/brevitas/core/restrict_val.html +++ b/docs/_modules/brevitas/core/restrict_val.html @@ -8,7 +8,7 @@ - brevitas.core.restrict_val — Brevitas 0.10.2 documentation + brevitas.core.restrict_val — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -446,7 +446,7 @@

Source code for brevitas.core.restrict_val

             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         x = self.clamp_min_ste(x)
         return x
@@ -462,7 +462,7 @@ 

Source code for brevitas.core.restrict_val

             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         return x
 
@@ -478,7 +478,7 @@ 

Source code for brevitas.core.restrict_val

         self.min_val = scaling_min_val
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.clamp_min_ste(x)
         return x
 
@@ -500,8 +500,11 @@ 

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_inplace_module(self): return Identity()
+
[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: + return x / threshold
+
[docs] @brevitas.jit.script_method - def forward(self, x: torch.Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: return x
@@ -514,7 +517,7 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_float(self, x: float): return math.log2(x)
-
[docs] def restrict_init_tensor(self, x: torch.Tensor): +
[docs] def restrict_init_tensor(self, x: Tensor): return torch.log2(x)
[docs] def restrict_init_module(self): @@ -523,8 +526,11 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_inplace_module(self): return InplaceLogTwo()
+
[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: + return x - threshold
+
[docs] @brevitas.jit.script_method - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.power_of_two(x) return x
@@ -538,7 +544,7 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_float(self, x: float): return x
-
[docs] def restrict_init_tensor(self, x: torch.Tensor): +
[docs] def restrict_init_tensor(self, x: Tensor): return x
[docs] def restrict_init_module(self): @@ -547,8 +553,11 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_inplace_module(self): return Identity()
+
[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: + return x / threshold
+
[docs] @brevitas.jit.script_method - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.float_to_int_impl(x) return x
@@ -563,7 +572,7 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_float(self, x: float): return math.log2(x)
-
[docs] def restrict_init_tensor(self, x: torch.Tensor): +
[docs] def restrict_init_tensor(self, x: Tensor): return torch.log2(x)
[docs] def restrict_init_module(self): @@ -572,8 +581,11 @@

Source code for brevitas.core.restrict_val

 
[docs] def restrict_init_inplace_module(self): return InplaceLogTwo()
+
[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: + return x - threshold
+
[docs] @brevitas.jit.script_method - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.float_to_int_impl(x) x = self.power_of_two(x) return x
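The two combine_scale_threshold variants are the same operation in different domains: dividing by a threshold in the linear domain equals subtracting its log2 in the power-of-two domain. A quick numerical check:

    import torch

    scale, threshold = torch.tensor(4.0), torch.tensor(3.0)
    linear = scale / threshold
    log_domain = torch.exp2(torch.log2(scale) - torch.log2(threshold))
    assert torch.allclose(linear, log_domain)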
diff --git a/docs/_modules/brevitas/core/scaling/int_scaling.html b/docs/_modules/brevitas/core/scaling/int_scaling.html index ab0663819..3dfd7c93d 100644 --- a/docs/_modules/brevitas/core/scaling/int_scaling.html +++ b/docs/_modules/brevitas/core/scaling/int_scaling.html @@ -8,7 +8,7 @@ - brevitas.core.scaling.int_scaling — Brevitas 0.10.2 documentation + brevitas.core.scaling.int_scaling — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/_modules/brevitas/core/scaling/runtime.html b/docs/_modules/brevitas/core/scaling/runtime.html index 566a80cc3..31e87e7bc 100644 --- a/docs/_modules/brevitas/core/scaling/runtime.html +++ b/docs/_modules/brevitas/core/scaling/runtime.html @@ -8,7 +8,7 @@ - brevitas.core.scaling.runtime — Brevitas 0.10.2 documentation + brevitas.core.scaling.runtime — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -421,6 +421,7 @@

Source code for brevitas.core.scaling.runtime

import brevitas.config as config from brevitas.core.function_wrapper import Identity from brevitas.core.restrict_val import _RestrictClampValue +from brevitas.core.restrict_val import FloatRestrictValue from brevitas.core.stats import _ParameterListStats from brevitas.core.stats import _RuntimeStats from brevitas.core.stats import DEFAULT_MOMENTUM @@ -437,8 +438,8 @@

Source code for brevitas.core.scaling.runtime

scaling_stats_input_view_shape_impl: Module, scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], - restrict_scaling_impl: Module, scaling_shape: Tuple[int, ...], + restrict_scaling_impl: Module = FloatRestrictValue(), affine_rescaling: bool = False, affine_shift_scale: bool = False, scaling_min_val: Optional[float] = None, @@ -461,9 +462,12 @@

Source code for brevitas.core.scaling.runtime

device)
[docs] @brevitas.jit.script_method - def forward(self, ignored: torch.Tensor) -> torch.Tensor: + def forward( + self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: stats = self.parameter_list_stats() - return self.stats_scaling_impl(stats)
+ if threshold is None: + threshold = torch.ones(1).type_as(stats) + return self.stats_scaling_impl(stats, threshold)
class _StatsScaling(brevitas.jit.ScriptModule): @@ -488,10 +492,16 @@

Source code for brevitas.core.scaling.runtime

self.affine_rescaling = Identity() self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() + self.restrict_scaling_impl = restrict_scaling_impl @brevitas.jit.script_method - def forward(self, stats: torch.Tensor) -> torch.Tensor: + def forward( + self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(stats) + threshold = self.restrict_scaling_pre(threshold) stats = self.restrict_scaling_pre(stats) + stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) return stats @@ -503,10 +513,10 @@

Source code for brevitas.core.scaling.runtime

self, scaling_stats_impl: Module, scaling_stats_input_view_shape_impl: Module, - restrict_scaling_impl: Module, scaling_shape: Tuple[int, ...], affine_rescaling: bool = False, affine_shift_scale: bool = False, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_stats_momentum: float = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, @@ -530,9 +540,9 @@

Source code for brevitas.core.scaling.runtime

device)
[docs] @brevitas.jit.script_method - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: stats = self.runtime_stats(x) - return self.stats_scaling_impl(stats)
+ return self.stats_scaling_impl(stats, threshold)
class _AffineRescaling(brevitas.jit.ScriptModule): @@ -568,6 +578,38 @@

Source code for brevitas.core.scaling.runtime

missing_keys.remove(affine_weight_key) if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys: missing_keys.remove(affine_bias_key) + + +
[docs]class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): + + def __init__( + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + scaling_stats_impl: Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Module = FloatRestrictValue()) -> None: + super(RuntimeDynamicGroupStatsScaling, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + self.scaling_stats_impl = scaling_stats_impl + self.scaling_min_val = scaling_min_val + self.input_view_impl = input_view_impl + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + +
[docs] @brevitas.jit.script_method + def forward( + self, + stats_input: torch.Tensor, + threshold: Optional[torch.Tensor] = None) -> torch.Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(stats_input) + stats_input_reshaped = self.input_view_impl(stats_input) + out = self.scaling_stats_impl(stats_input_reshaped) / threshold + # Clamp the scale to scaling_min_val, if set + out = self.restrict_clamp_scaling(out) + return out
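Stripped of the Brevitas wrappers, the dynamic groupwise scale amounts to a per-group reduction over a padded-and-reshaped view; a sketch with an abs-max statistic standing in for scaling_stats_impl:

    import torch

    x = torch.randn(8, 16)
    groups = x.view(8, 4, 4)                         # group_size=4 along dim 1, no padding needed
    scale = groups.abs().amax(dim=-1, keepdim=True)  # one scale per group
    print(scale.shape)                               # torch.Size([8, 4, 1])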
diff --git a/docs/_modules/brevitas/core/scaling/standalone.html b/docs/_modules/brevitas/core/scaling/standalone.html index 54ec5390e..2948d3552 100644 --- a/docs/_modules/brevitas/core/scaling/standalone.html +++ b/docs/_modules/brevitas/core/scaling/standalone.html @@ -8,7 +8,7 @@ - brevitas.core.scaling.standalone — Brevitas 0.10.2 documentation + brevitas.core.scaling.standalone — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -425,6 +425,7 @@

Source code for brevitas.core.scaling.standalone

from brevitas.core.restrict_val import _ClampValue from brevitas.core.restrict_val import _RestrictClampValue from brevitas.core.restrict_val import _RestrictValue +from brevitas.core.restrict_val import FloatRestrictValue from brevitas.core.scaling.runtime import _StatsScaling from brevitas.core.stats import _ParameterListStats from brevitas.core.stats import _Stats @@ -470,7 +471,7 @@

Source code for brevitas.core.scaling.standalone

def __init__( self, scaling_init: Union[float, Tensor], - restrict_scaling_impl: Optional[Module] = None, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -478,18 +479,23 @@

Source code for brevitas.core.scaling.standalone

self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) if isinstance(scaling_init, Tensor): scaling_init = scaling_init.to(device=device, dtype=dtype) - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(scaling_init.detach()) else: - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) + scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
[docs] @brevitas.jit.script_method - def forward(self, placeholder: Tensor) -> Tensor: - value = self.value() - restricted_value = self.restrict_clamp_scaling(value) + def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(placeholder) + # We first apply any restriction to the scaling value. + # For IntQuant this is a no-op, keeping the behaviour backward compatible. + threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + restricted_value = self.restrict_clamp_scaling(self.value()) + restricted_value = restricted_value / threshold return restricted_value
@@ -536,7 +542,7 @@

Source code for brevitas.core.scaling.standalone

self, scaling_init: Union[float, Tensor], scaling_shape: Optional[Tuple[int, ...]] = None, - restrict_scaling_impl: Optional[Module] = None, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -551,17 +557,24 @@

Source code for brevitas.core.scaling.standalone

scaling_init = scaling_init.detach() else: scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device) - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + + scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() + if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None: scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device) self.value = Parameter(scaling_init) self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
[docs] @brevitas.jit.script_method - def forward(self, placeholder: Tensor) -> Tensor: + def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(placeholder) + # We first apply any restriction to the scaling value. + # For IntQuant this is a no-op, keeping the behaviour backward compatible. + threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) - return value
+ return value / threshold
def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, @@ -588,8 +601,8 @@

Source code for brevitas.core.scaling.standalone

scaling_stats_input_view_shape_impl: Module, scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], - restrict_scaling_impl: Module, scaling_shape: Tuple[int, ...], + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -600,30 +613,38 @@

Source code for brevitas.core.scaling.standalone

scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list) + self.restrict_scaling_impl = restrict_scaling_impl self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device) self.init_done: bool = brevitas.jit.Attribute(False, bool) self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool) - if restrict_scaling_impl is not None: - self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - else: - self.restrict_inplace_preprocess = Identity() + self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() + self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
[docs] @brevitas.jit.script_method - def forward(self, ignored: torch.Tensor) -> torch.Tensor: + def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(ignored) + # Threshold division must happen after we update self.value, but before we apply restrict_preprocess + # This is because we don't want to store a parameter dependent on a runtime value (threshold) + # And because restrict needs to happen after we divide by threshold if self.init_done: - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + threshold = self.restrict_inplace_preprocess(threshold) + value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) return value else: stats = self.parameter_list_stats() # workaround to avoid find_unused_parameters=True in DDP stats = stats + 0. * self.value if self.local_loss_mode: - return self.stats_scaling_impl(stats) + return self.stats_scaling_impl(stats, threshold) stats = self.restrict_inplace_preprocess(stats) + threshold = self.restrict_inplace_preprocess(threshold) inplace_tensor_mul(self.value.detach(), stats) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) self.init_done = True return value
@@ -631,7 +652,7 @@

Source code for brevitas.core.scaling.standalone

output_dict = super(ParameterFromStatsFromParameterScaling, self).state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) # Avoid saving the init value - if not self.init_done: + if not self.init_done and not config._FULL_STATE_DICT: del output_dict[prefix + 'value'] return output_dict
@@ -700,7 +721,7 @@

Source code for brevitas.core.scaling.standalone

scaling_stats_impl: Module, scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(), scaling_shape: Tuple[int, ...] = SCALAR_SHAPE, - restrict_scaling_impl: Optional[Module] = None, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, @@ -715,19 +736,19 @@

Source code for brevitas.core.scaling.standalone

scaling_stats_momentum, Optional[float]) self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) + self.restrict_scaling_impl = restrict_scaling_impl self.restrict_scaling = _RestrictValue(restrict_scaling_impl) self.clamp_scaling = _ClampValue(scaling_min_val) self.local_loss_mode: bool = brevitas.jit.Attribute( False, bool) # required to support MSE eval or variants - if restrict_scaling_impl is not None: - self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() - else: - self.restrict_inplace_preprocess = Identity() - self.restrict_preprocess = Identity() + self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() + self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
[docs] @brevitas.jit.script_method - def training_forward(self, stats_input: Tensor) -> Tensor: + def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: + # Threshold division must happen after we update self.value, but before we apply restrict_preprocess + # This is because we don't want to store a parameter dependent on a runtime value (threshold) + # And because restrict needs to happen after we divide by threshold if self.counter < self.collect_stats_steps: stats_input = self.stats_input_view_shape_impl(stats_input) stats = self.stats(stats_input) @@ -737,32 +758,41 @@

Source code for brevitas.core.scaling.standalone

new_counter = self.counter + 1 # Whenever we are in local loss mode, we don't update the counter nor the buffer if self.local_loss_mode: - return abs_binary_sign_grad(clamped_stats) + # In local loss mode we exit early and divide by the threshold + return abs_binary_sign_grad(clamped_stats / threshold) if self.counter == 0: inplace_tensor_mul(self.buffer, clamped_stats.detach()) else: inplace_momentum_update( self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter) self.counter = new_counter - return abs_binary_sign_grad(clamped_stats) + return abs_binary_sign_grad(clamped_stats / threshold) elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) + threshold = self.restrict_preprocess(threshold) + value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) else: - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+ threshold = self.restrict_preprocess(threshold) + value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
[docs] @brevitas.jit.script_method - def forward(self, stats_input: Tensor) -> Tensor: + def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: + if threshold is None: + threshold = torch.ones(1).type_as(stats_input) if self.training: - return self.training_forward(stats_input) + # Threshold division handled inside the training_forward + return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer + out = self.buffer / threshold out = self.restrict_preprocess(out) else: - out = self.value + threshold = self.restrict_preprocess(threshold) + out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) return out
@@ -772,7 +802,7 @@

Source code for brevitas.core.scaling.standalone

# Avoid saving the buffer del output_dict[prefix + 'buffer'] # Avoid saving the init value - if self.counter == 0: + if self.counter == 0 and not config._FULL_STATE_DICT: del output_dict[prefix + 'value'] # Save buffer into value for any non-zero number of collection steps elif self.counter <= self.collect_stats_steps: diff --git a/docs/_modules/brevitas/core/stats/stats_op.html b/docs/_modules/brevitas/core/stats/stats_op.html index 5d95da826..18dffb1fc 100644 --- a/docs/_modules/brevitas/core/stats/stats_op.html +++ b/docs/_modules/brevitas/core/stats/stats_op.html @@ -8,7 +8,7 @@ - brevitas.core.stats.stats_op — Brevitas 0.10.2 documentation + brevitas.core.stats.stats_op — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -420,8 +420,11 @@

Source code for brevitas.core.stats.stats_op

import brevitas
 from brevitas import config
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_int
+from brevitas.quant_tensor import _unpack_quant_tensor
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
 
@@ -849,6 +852,19 @@ 

Source code for brevitas.core.stats.stats_op

m.local_loss_mode = enabled
 
 
+def _set_observer_mode(module, enabled, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            previous_observer_mode[m] = m.observer_only
+            m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            m.observer_only = previous_observer_mode[m]
+
+
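A minimal round trip with these helpers, using a stand-in module (the real proxies set observer_only on their quantization implementations):

    import torch

    class DummyQuant(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.observer_only = False

        def forward(self, x):
            return x if self.observer_only else torch.round(x)

    model = torch.nn.Sequential(DummyQuant())
    previous = {}
    _set_observer_mode(model, True, previous)  # record current modes, enable pass-through
    _ = model(torch.randn(4))                  # statistics can be observed; tensor is unquantized
    _restore_observer_mode(model, previous)    # every module back to its previous mode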
 
[docs]class MSE(torch.nn.Module): # References: # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py @@ -866,7 +882,12 @@

Source code for brevitas.core.stats.stats_op

self.mse_init_op = mse_init_op
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.num = mse_iters
         self.search_method = mse_search_method
@@ -887,11 +908,12 @@ 

Source code for brevitas.core.stats.stats_op

self.internal_candidate = candidate
         # Set to local_loss_mode before calling the proxy
         self.set_local_loss_mode(True)
+        self.set_observer_mode(False)
         quant_value = self.proxy_forward(x)
-        if isinstance(quant_value, tuple):
-            quant_value = quant_value[0]
+        quant_value = _unpack_quant_tensor(quant_value)
         loss = self.mse_loss_fn(x, quant_value)
         self.set_local_loss_mode(False)
+        self.restore_observer_mode()
         return loss
+ + +
[docs]class HalfQuadraticOptimizerScale(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + hqo_init_op_scale, + keepdim: bool, + inner_stats_input_view_shape_impl: torch.nn.Module, + scaling_min_val: Optional[float] = None, + stats_reduce_dim: Optional[int] = None, + int_scaling_impl=None, + bit_width_impl=None, + hqo_beta_scale: float = 1e5, + hqo_kappa_scale: float = 1.01, + hqo_lp_norm_scale: float = .7, + hqo_iters_scale: int = 1000): + super(HalfQuadraticOptimizerScale, self).__init__() + self.hqo_init_op = hqo_init_op_scale + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) + self.internal_candidate = None + self.hqo_iters = hqo_iters_scale + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + + self.beta = hqo_beta_scale + self.kappa = hqo_kappa_scale + self.lp_norm = hqo_lp_norm_scale + + self.int_scaling_impl = int_scaling_impl + self.msb_clamp_bit_width_impl = bit_width_impl + if scaling_min_val is not None and scaling_min_val != 0: + self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) + else: + self.clamp_min_ste = Identity() + self.keepdim = keepdim + + + +
[docs] def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op(x_view).detach() + best_candidate = self.parameter_search(init, x_view) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate
+ +
[docs] def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op(x).detach() + return self.internal_candidate
+ + +
[docs]class HalfQuadraticOptimizerZeroPoint(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + keepdim: bool, + hqo_init_op_zp: torch.nn.Module, + inner_stats_input_view_shape_impl: torch.nn.Module, + stats_reduce_dim: Optional[int] = None, + hqo_beta_zp: float = 1e0, + hqo_kappa_zp: float = 1.01, + hqo_lp_norm_zp: float = .5, + hqo_iters_zp: int = 1000): + super(HalfQuadraticOptimizerZeroPoint, self).__init__() + self.hqo_init_op_zp = hqo_init_op_zp + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) + self.internal_candidate = None + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + self.beta = hqo_beta_zp + self.kappa = hqo_kappa_zp + self.lp_norm = hqo_lp_norm_zp + self.hqo_iters = hqo_iters_zp + self.keepdim = keepdim + + + +
[docs] def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op_zp(x_view).detach() + + best_candidate = self.parameter_search(init, x) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate
+ +
[docs] def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op_zp(x).detach() + return self.internal_candidate
+ + +
[docs]def masked_median(x, mask, dim=None, keepdim=False): + """Compute the median of tensor x along dim, ignoring values where mask is False. + x and mask need to be broadcastable. + + Args: + x (Tensor): Tensor to compute median of. + mask (BoolTensor): Same shape as x, with True where x is valid and False + where x should be masked. The mask should not be all False along any + column of dimension dim, to avoid NaNs from zero division. + dim (int, optional): Dimension to take the median of. Defaults to None, + which reduces over all elements. + + Returns: + Tensor: Same shape as x, except dimension dim is reduced. + """ + # uncomment this assert for safety, at a potential performance cost + # assert ( + # mask.sum(dim=dim).ne(0).all() + # ), "mask should not be all False in any column, causes zero division" + x_nan = x.float().masked_fill(~mask, float("nan")) + if dim is None: + x_median = x_nan.nanmedian() + else: + x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim) + return x_median
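For example, masking out an outlier before the reduction (note torch.nanmedian returns the lower of the two middle values for an even count):

    import torch

    x = torch.tensor([[1., 2., 100.], [4., 5., 6.]])
    mask = torch.tensor([[True, True, False], [True, True, True]])
    print(masked_median(x, mask, dim=1))  # tensor([1., 5.])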
+ + +# Shrinking operator +
[docs]def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor: + if lp_norm == 1: + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
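For lp_norm=1 this is the classic soft-thresholding (shrinkage) operator used inside the half-quadratic optimizers above: magnitudes shrink by 1/beta, and values within 1/beta of zero collapse to it:

    import torch

    x = torch.tensor([-2.0, -0.05, 0.3, 1.5])
    y = shrink_lp_op(x, beta=10., lp_norm=1)
    # magnitudes shrink by 1 / beta = 0.1 -> approximately [-1.9, 0.0, 0.2, 1.4]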
diff --git a/docs/_modules/brevitas/core/utils.html b/docs/_modules/brevitas/core/utils.html index 6c90c9602..a7378d4d7 100644 --- a/docs/_modules/brevitas/core/utils.html +++ b/docs/_modules/brevitas/core/utils.html @@ -8,7 +8,7 @@ - brevitas.core.utils — Brevitas 0.10.2 documentation + brevitas.core.utils — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/_modules/brevitas/core/zero_point.html b/docs/_modules/brevitas/core/zero_point.html index 53919717c..986cdc5a3 100644 --- a/docs/_modules/brevitas/core/zero_point.html +++ b/docs/_modules/brevitas/core/zero_point.html @@ -8,7 +8,7 @@ - brevitas.core.zero_point — Brevitas 0.10.2 documentation + brevitas.core.zero_point — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -691,8 +691,9 @@

Source code for brevitas.core.zero_point

         output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict(
             destination=destination, prefix=prefix, keep_vars=keep_vars)
         # Avoid saving the init value
-        if not self.init_done:
-            del output_dict[prefix + 'value']
+ if not self.init_done and not config._FULL_STATE_DICT: + del output_dict[prefix + 'value'] + return output_dict
def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/docs/_modules/brevitas/function/ops.html b/docs/_modules/brevitas/function/ops.html index c325b40bd..53806a3b6 100644 --- a/docs/_modules/brevitas/function/ops.html +++ b/docs/_modules/brevitas/function/ops.html @@ -8,7 +8,7 @@ - brevitas.function.ops — Brevitas 0.10.2 documentation + brevitas.function.ops — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -599,7 +599,7 @@

Source code for brevitas.function.ops

     return value
-
[docs]@brevitas.jit.script +
[docs]@brevitas.jit.ignore def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias max_mantissa = torch.sum(( diff --git a/docs/_modules/brevitas/function/ops_ste.html b/docs/_modules/brevitas/function/ops_ste.html index 14b00539e..182924071 100644 --- a/docs/_modules/brevitas/function/ops_ste.html +++ b/docs/_modules/brevitas/function/ops_ste.html @@ -8,7 +8,7 @@ - brevitas.function.ops_ste — Brevitas 0.10.2 documentation + brevitas.function.ops_ste — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
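With the standard minifloat definition, max_float evaluates to (2 - 2**-mantissa_bit_width) * 2**max_exponent; worked by hand for E4M3 with bias 7 (the OCP spec then reserves the all-ones code for NaN, which is what max_available_float caps at 448 in FloatClamp):

    # Largest normal value of an E4M3-style minifloat, computed by hand:
    e_bits, m_bits, bias = 4, 3, 7
    max_exponent = 2. ** e_bits - 1. - bias   # 8.0
    max_mantissa = 2. - 2. ** -m_bits         # 1.875
    print(max_mantissa * 2. ** max_exponent)  # 480.0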
diff --git a/docs/_modules/brevitas/function/shape.html b/docs/_modules/brevitas/function/shape.html index 9b6e94515..cb64c5178 100644 --- a/docs/_modules/brevitas/function/shape.html +++ b/docs/_modules/brevitas/function/shape.html @@ -8,7 +8,7 @@ - brevitas.function.shape — Brevitas 0.10.2 documentation + brevitas.function.shape — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/_modules/brevitas/ops/autograd_ste_ops.html b/docs/_modules/brevitas/ops/autograd_ste_ops.html index 21b701fe8..dbb265da7 100644 --- a/docs/_modules/brevitas/ops/autograd_ste_ops.html +++ b/docs/_modules/brevitas/ops/autograd_ste_ops.html @@ -8,7 +8,7 @@ - brevitas.ops.autograd_ste_ops — Brevitas 0.10.2 documentation + brevitas.ops.autograd_ste_ops — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/_modules/index.html b/docs/_modules/index.html index 5380f2cf8..c1ddf62d6 100644 --- a/docs/_modules/index.html +++ b/docs/_modules/index.html @@ -8,7 +8,7 @@ - Overview: module code — Brevitas 0.10.2 documentation + Overview: module code — Brevitas 0.11.0 documentation @@ -123,8 +123,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/_static/documentation_options.js b/docs/_static/documentation_options.js index db6d22fe9..9dc22d647 100644 --- a/docs/_static/documentation_options.js +++ b/docs/_static/documentation_options.js @@ -1,6 +1,6 @@ var DOCUMENTATION_OPTIONS = { URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), - VERSION: '0.10.2', + VERSION: '0.11.0', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/_static/pygments.css b/docs/_static/pygments.css index 997797f27..012e6a00a 100644 --- a/docs/_static/pygments.css +++ b/docs/_static/pygments.css @@ -3,77 +3,77 @@ html[data-theme="light"] .highlight td.linenos .normal { color: inherit; backgro html[data-theme="light"] .highlight span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } html[data-theme="light"] .highlight td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } html[data-theme="light"] .highlight span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } -html[data-theme="light"] .highlight .hll { background-color: #7971292e } -html[data-theme="light"] .highlight { background: #fefefe; color: #545454 } -html[data-theme="light"] .highlight .c { color: #797129 } /* Comment */ -html[data-theme="light"] .highlight .err { color: #d91e18 } /* Error */ -html[data-theme="light"] .highlight .k { color: #7928a1 } /* Keyword */ -html[data-theme="light"] .highlight .l { color: #797129 } /* Literal */ -html[data-theme="light"] .highlight .n { color: #545454 } /* Name */ -html[data-theme="light"] .highlight .o { color: #008000 } /* Operator */ -html[data-theme="light"] .highlight .p { color: #545454 } /* Punctuation */ -html[data-theme="light"] .highlight .ch { color: #797129 } /* Comment.Hashbang */ -html[data-theme="light"] .highlight .cm { color: #797129 } /* Comment.Multiline */ -html[data-theme="light"] .highlight .cp { color: #797129 } /* Comment.Preproc */ -html[data-theme="light"] .highlight .cpf { color: #797129 } /* Comment.PreprocFile */ -html[data-theme="light"] .highlight .c1 { color: #797129 } /* Comment.Single */ -html[data-theme="light"] .highlight .cs { color: #797129 } /* Comment.Special */ -html[data-theme="light"] .highlight .gd { color: #007faa } /* Generic.Deleted */ +html[data-theme="light"] .highlight .hll { background-color: #fae4c2 } +html[data-theme="light"] .highlight { background: #fefefe; color: #080808 } +html[data-theme="light"] .highlight .c { color: #515151 } /* Comment */ +html[data-theme="light"] .highlight .err { color: #a12236 } /* Error */ +html[data-theme="light"] .highlight .k { color: #6730c5 } /* Keyword */ +html[data-theme="light"] .highlight .l { color: #7f4707 } /* Literal */ +html[data-theme="light"] .highlight .n { color: #080808 } /* Name */ +html[data-theme="light"] .highlight .o { color: #00622f } /* Operator */ +html[data-theme="light"] .highlight .p { color: #080808 } /* Punctuation */ +html[data-theme="light"] .highlight .ch { color: #515151 } /* Comment.Hashbang */ +html[data-theme="light"] .highlight .cm { color: #515151 } /* Comment.Multiline */ +html[data-theme="light"] .highlight .cp { color: #515151 } /* Comment.Preproc */ +html[data-theme="light"] .highlight .cpf { color: #515151 } /* Comment.PreprocFile */ +html[data-theme="light"] .highlight .c1 { color: #515151 } /* Comment.Single */ +html[data-theme="light"] .highlight .cs { color: #515151 } /* Comment.Special */ 
+html[data-theme="light"] .highlight .gd { color: #005b82 } /* Generic.Deleted */ html[data-theme="light"] .highlight .ge { font-style: italic } /* Generic.Emph */ -html[data-theme="light"] .highlight .gh { color: #007faa } /* Generic.Heading */ +html[data-theme="light"] .highlight .gh { color: #005b82 } /* Generic.Heading */ html[data-theme="light"] .highlight .gs { font-weight: bold } /* Generic.Strong */ -html[data-theme="light"] .highlight .gu { color: #007faa } /* Generic.Subheading */ -html[data-theme="light"] .highlight .kc { color: #7928a1 } /* Keyword.Constant */ -html[data-theme="light"] .highlight .kd { color: #7928a1 } /* Keyword.Declaration */ -html[data-theme="light"] .highlight .kn { color: #7928a1 } /* Keyword.Namespace */ -html[data-theme="light"] .highlight .kp { color: #7928a1 } /* Keyword.Pseudo */ -html[data-theme="light"] .highlight .kr { color: #7928a1 } /* Keyword.Reserved */ -html[data-theme="light"] .highlight .kt { color: #797129 } /* Keyword.Type */ -html[data-theme="light"] .highlight .ld { color: #797129 } /* Literal.Date */ -html[data-theme="light"] .highlight .m { color: #797129 } /* Literal.Number */ -html[data-theme="light"] .highlight .s { color: #008000 } /* Literal.String */ -html[data-theme="light"] .highlight .na { color: #797129 } /* Name.Attribute */ -html[data-theme="light"] .highlight .nb { color: #797129 } /* Name.Builtin */ -html[data-theme="light"] .highlight .nc { color: #007faa } /* Name.Class */ -html[data-theme="light"] .highlight .no { color: #007faa } /* Name.Constant */ -html[data-theme="light"] .highlight .nd { color: #797129 } /* Name.Decorator */ -html[data-theme="light"] .highlight .ni { color: #008000 } /* Name.Entity */ -html[data-theme="light"] .highlight .ne { color: #7928a1 } /* Name.Exception */ -html[data-theme="light"] .highlight .nf { color: #007faa } /* Name.Function */ -html[data-theme="light"] .highlight .nl { color: #797129 } /* Name.Label */ -html[data-theme="light"] .highlight .nn { color: #545454 } /* Name.Namespace */ -html[data-theme="light"] .highlight .nx { color: #545454 } /* Name.Other */ -html[data-theme="light"] .highlight .py { color: #007faa } /* Name.Property */ -html[data-theme="light"] .highlight .nt { color: #007faa } /* Name.Tag */ -html[data-theme="light"] .highlight .nv { color: #d91e18 } /* Name.Variable */ -html[data-theme="light"] .highlight .ow { color: #7928a1 } /* Operator.Word */ -html[data-theme="light"] .highlight .pm { color: #545454 } /* Punctuation.Marker */ -html[data-theme="light"] .highlight .w { color: #545454 } /* Text.Whitespace */ -html[data-theme="light"] .highlight .mb { color: #797129 } /* Literal.Number.Bin */ -html[data-theme="light"] .highlight .mf { color: #797129 } /* Literal.Number.Float */ -html[data-theme="light"] .highlight .mh { color: #797129 } /* Literal.Number.Hex */ -html[data-theme="light"] .highlight .mi { color: #797129 } /* Literal.Number.Integer */ -html[data-theme="light"] .highlight .mo { color: #797129 } /* Literal.Number.Oct */ -html[data-theme="light"] .highlight .sa { color: #008000 } /* Literal.String.Affix */ -html[data-theme="light"] .highlight .sb { color: #008000 } /* Literal.String.Backtick */ -html[data-theme="light"] .highlight .sc { color: #008000 } /* Literal.String.Char */ -html[data-theme="light"] .highlight .dl { color: #008000 } /* Literal.String.Delimiter */ -html[data-theme="light"] .highlight .sd { color: #008000 } /* Literal.String.Doc */ -html[data-theme="light"] .highlight .s2 { color: #008000 } /* Literal.String.Double */ 
-html[data-theme="light"] .highlight .se { color: #008000 } /* Literal.String.Escape */ -html[data-theme="light"] .highlight .sh { color: #008000 } /* Literal.String.Heredoc */ -html[data-theme="light"] .highlight .si { color: #008000 } /* Literal.String.Interpol */ -html[data-theme="light"] .highlight .sx { color: #008000 } /* Literal.String.Other */ -html[data-theme="light"] .highlight .sr { color: #d91e18 } /* Literal.String.Regex */ -html[data-theme="light"] .highlight .s1 { color: #008000 } /* Literal.String.Single */ -html[data-theme="light"] .highlight .ss { color: #007faa } /* Literal.String.Symbol */ -html[data-theme="light"] .highlight .bp { color: #797129 } /* Name.Builtin.Pseudo */ -html[data-theme="light"] .highlight .fm { color: #007faa } /* Name.Function.Magic */ -html[data-theme="light"] .highlight .vc { color: #d91e18 } /* Name.Variable.Class */ -html[data-theme="light"] .highlight .vg { color: #d91e18 } /* Name.Variable.Global */ -html[data-theme="light"] .highlight .vi { color: #d91e18 } /* Name.Variable.Instance */ -html[data-theme="light"] .highlight .vm { color: #797129 } /* Name.Variable.Magic */ -html[data-theme="light"] .highlight .il { color: #797129 } /* Literal.Number.Integer.Long */ +html[data-theme="light"] .highlight .gu { color: #005b82 } /* Generic.Subheading */ +html[data-theme="light"] .highlight .kc { color: #6730c5 } /* Keyword.Constant */ +html[data-theme="light"] .highlight .kd { color: #6730c5 } /* Keyword.Declaration */ +html[data-theme="light"] .highlight .kn { color: #6730c5 } /* Keyword.Namespace */ +html[data-theme="light"] .highlight .kp { color: #6730c5 } /* Keyword.Pseudo */ +html[data-theme="light"] .highlight .kr { color: #6730c5 } /* Keyword.Reserved */ +html[data-theme="light"] .highlight .kt { color: #7f4707 } /* Keyword.Type */ +html[data-theme="light"] .highlight .ld { color: #7f4707 } /* Literal.Date */ +html[data-theme="light"] .highlight .m { color: #7f4707 } /* Literal.Number */ +html[data-theme="light"] .highlight .s { color: #00622f } /* Literal.String */ +html[data-theme="light"] .highlight .na { color: #912583 } /* Name.Attribute */ +html[data-theme="light"] .highlight .nb { color: #7f4707 } /* Name.Builtin */ +html[data-theme="light"] .highlight .nc { color: #005b82 } /* Name.Class */ +html[data-theme="light"] .highlight .no { color: #005b82 } /* Name.Constant */ +html[data-theme="light"] .highlight .nd { color: #7f4707 } /* Name.Decorator */ +html[data-theme="light"] .highlight .ni { color: #00622f } /* Name.Entity */ +html[data-theme="light"] .highlight .ne { color: #6730c5 } /* Name.Exception */ +html[data-theme="light"] .highlight .nf { color: #005b82 } /* Name.Function */ +html[data-theme="light"] .highlight .nl { color: #7f4707 } /* Name.Label */ +html[data-theme="light"] .highlight .nn { color: #080808 } /* Name.Namespace */ +html[data-theme="light"] .highlight .nx { color: #080808 } /* Name.Other */ +html[data-theme="light"] .highlight .py { color: #005b82 } /* Name.Property */ +html[data-theme="light"] .highlight .nt { color: #005b82 } /* Name.Tag */ +html[data-theme="light"] .highlight .nv { color: #a12236 } /* Name.Variable */ +html[data-theme="light"] .highlight .ow { color: #6730c5 } /* Operator.Word */ +html[data-theme="light"] .highlight .pm { color: #080808 } /* Punctuation.Marker */ +html[data-theme="light"] .highlight .w { color: #080808 } /* Text.Whitespace */ +html[data-theme="light"] .highlight .mb { color: #7f4707 } /* Literal.Number.Bin */ +html[data-theme="light"] .highlight .mf { color: #7f4707 } /* 
Literal.Number.Float */ +html[data-theme="light"] .highlight .mh { color: #7f4707 } /* Literal.Number.Hex */ +html[data-theme="light"] .highlight .mi { color: #7f4707 } /* Literal.Number.Integer */ +html[data-theme="light"] .highlight .mo { color: #7f4707 } /* Literal.Number.Oct */ +html[data-theme="light"] .highlight .sa { color: #00622f } /* Literal.String.Affix */ +html[data-theme="light"] .highlight .sb { color: #00622f } /* Literal.String.Backtick */ +html[data-theme="light"] .highlight .sc { color: #00622f } /* Literal.String.Char */ +html[data-theme="light"] .highlight .dl { color: #00622f } /* Literal.String.Delimiter */ +html[data-theme="light"] .highlight .sd { color: #00622f } /* Literal.String.Doc */ +html[data-theme="light"] .highlight .s2 { color: #00622f } /* Literal.String.Double */ +html[data-theme="light"] .highlight .se { color: #00622f } /* Literal.String.Escape */ +html[data-theme="light"] .highlight .sh { color: #00622f } /* Literal.String.Heredoc */ +html[data-theme="light"] .highlight .si { color: #00622f } /* Literal.String.Interpol */ +html[data-theme="light"] .highlight .sx { color: #00622f } /* Literal.String.Other */ +html[data-theme="light"] .highlight .sr { color: #a12236 } /* Literal.String.Regex */ +html[data-theme="light"] .highlight .s1 { color: #00622f } /* Literal.String.Single */ +html[data-theme="light"] .highlight .ss { color: #005b82 } /* Literal.String.Symbol */ +html[data-theme="light"] .highlight .bp { color: #7f4707 } /* Name.Builtin.Pseudo */ +html[data-theme="light"] .highlight .fm { color: #005b82 } /* Name.Function.Magic */ +html[data-theme="light"] .highlight .vc { color: #a12236 } /* Name.Variable.Class */ +html[data-theme="light"] .highlight .vg { color: #a12236 } /* Name.Variable.Global */ +html[data-theme="light"] .highlight .vi { color: #a12236 } /* Name.Variable.Instance */ +html[data-theme="light"] .highlight .vm { color: #7f4707 } /* Name.Variable.Magic */ +html[data-theme="light"] .highlight .il { color: #7f4707 } /* Literal.Number.Integer.Long */ html[data-theme="dark"] .highlight pre { line-height: 125%; } html[data-theme="dark"] .highlight td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } html[data-theme="dark"] .highlight span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } diff --git a/docs/about.html b/docs/about.html index 8c96f9c07..f62066e2e 100644 --- a/docs/about.html +++ b/docs/about.html @@ -9,7 +9,7 @@ - About — Brevitas 0.10.2 documentation + About — Brevitas 0.11.0 documentation @@ -125,8 +125,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
diff --git a/docs/api_reference/brevitas.core.bit_width.html b/docs/api_reference/brevitas.core.bit_width.html index c2a0a5090..226dccfa2 100644 --- a/docs/api_reference/brevitas.core.bit_width.html +++ b/docs/api_reference/brevitas.core.bit_width.html @@ -9,7 +9,7 @@ - brevitas.core.bit_width package — Brevitas 0.10.2 documentation + brevitas.core.bit_width package — Brevitas 0.11.0 documentation @@ -126,8 +126,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +
@@ -447,11 +447,11 @@

Submodules
class brevitas.core.bit_width.const.BitWidthConst(bit_width, dtype=None, device=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule that returns a constant bit-width wrapped in a float torch.tensor.

Parameters:
-

bit_width (int) – bit-width value.

+

bit_width (int) – bit-width value.

Examples

@@ -472,9 +472,9 @@

Submodules
forward()[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -489,12 +489,12 @@

Submodules
class brevitas.core.bit_width.const.BitWidthStatefulConst(bit_width, dtype=None, device=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule that returns a constant bit-width wrapped in a float torch.tensor but retains the bit-width as part of the module state.

Parameters:
-

bit_width (int) – bit-width value.

+

bit_width (int) – bit-width value.

Examples

@@ -517,9 +517,9 @@

Submodules
forward()[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -534,13 +534,13 @@
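To make the two constant bit-width modules above concrete, here is a minimal sketch of what they amount to (an illustrative reimplementation, not the Brevitas source): both return the bit-width as a float tensor, and the stateful variant additionally keeps it in a buffer so that it travels with the module's state_dict.

import torch
from torch import nn

class ConstBitWidthSketch(nn.Module):
    # Stateless: the bit-width is returned as a fresh float tensor on every call.
    def __init__(self, bit_width: int):
        super().__init__()
        self.bit_width = float(bit_width)

    def forward(self) -> torch.Tensor:
        return torch.tensor(self.bit_width)

class StatefulConstBitWidthSketch(nn.Module):
    # Stateful: the bit-width lives in a registered buffer, so it is saved and
    # loaded with state_dict, mirroring what BitWidthStatefulConst describes.
    def __init__(self, bit_width: int):
        super().__init__()
        self.register_buffer('bit_width', torch.tensor(float(bit_width)))

    def forward(self) -> torch.Tensor:
        return self.bit_width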

Submodules
class brevitas.core.bit_width.const.MsbClampBitWidth(bit_width_to_remove_impl, min_overall_bit_width, max_overall_bit_width)[source]#
-

Bases: Module

+

Bases: Module

forward(input_bit_width)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -558,15 +558,15 @@

Submodules
class brevitas.core.bit_width.parameter.BitWidthParameter(bit_width, min_bit_width=2, restrict_bit_width_impl=IntRestrictValue(   (float_to_int_impl): RoundSte() ), override_pretrained_bit_width=False, dtype=None, device=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule that returns a learnable bit-width wrapped in a float torch.Tensor.

Parameters:
    -
  • bit_width (int) – value to initialize the output learned bit-width.

  • -
  • min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.

  • -
  • restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).

  • -
  • override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.

  • +
  • bit_width (int) – value to initialize the output learned bit-width.

  • +
  • min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.

  • +
  • restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).

  • +
  • override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.

Returns:
@@ -576,7 +576,7 @@

Submodules

Tensor

Raises:
-

RuntimeError – if bit_width < min_bit_width.

+

RuntimeError – if bit_width < min_bit_width.

Examples

@@ -597,9 +597,9 @@

Submodules
forward()[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -614,13 +614,13 @@
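A learnable bit-width as described above can be sketched as follows (illustrative only; the class name and the exact parameterization are mine, and the real module delegates rounding to restrict_bit_width_impl): the bit-width is stored as a trainable offset above min_bit_width and rounded to an integer with a straight-through estimator, matching the documented RuntimeError when bit_width < min_bit_width.

import torch
from torch import nn

class LearnedBitWidthSketch(nn.Module):
    def __init__(self, bit_width: float, min_bit_width: float = 2.):
        super().__init__()
        if bit_width < min_bit_width:
            raise RuntimeError('bit_width cannot be lower than min_bit_width')
        self.min_bit_width = min_bit_width
        self.offset = nn.Parameter(torch.tensor(bit_width - min_bit_width))

    def forward(self) -> torch.Tensor:
        bit_width = self.min_bit_width + self.offset.abs()
        # Round in the forward pass, let gradients pass straight through.
        return bit_width + (bit_width.round() - bit_width).detach()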

Submodules
class brevitas.core.bit_width.parameter.RemoveBitwidthParameter(bit_width_to_remove, override_pretrained_bit_width=False, non_zero_epsilon=1e-06, remove_zero_bit_width=0.1, dtype=None, device=None)[source]#
-

Bases: Module

+

Bases: Module

forward()[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within diff --git a/docs/api_reference/brevitas.core.function_wrapper.html b/docs/api_reference/brevitas.core.function_wrapper.html index 458e945a3..26217f590 100644 --- a/docs/api_reference/brevitas.core.function_wrapper.html +++ b/docs/api_reference/brevitas.core.function_wrapper.html @@ -9,7 +9,7 @@ - brevitas.core.function_wrapper package — Brevitas 0.10.2 documentation + brevitas.core.function_wrapper package — Brevitas 0.11.0 documentation @@ -126,8 +126,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +

@@ -448,7 +448,7 @@

Submodules
class brevitas.core.function_wrapper.clamp.ClampMin(min_val)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for clamp_min().

Examples

>>> clamp_min = ClampMin(min_val=-2.0)
@@ -459,7 +459,7 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -472,11 +472,45 @@

Submodules +
+class brevitas.core.function_wrapper.clamp.FloatClamp(tensor_clamp_impl, signed, inf_values=None, nan_values=None, max_available_float=None, saturating=True, device=None, dtype=None)[source]#
+

Bases: Module

+

+ScriptModule for clamping minifloat formats to their inf/NaN implementations.

+

Currently, inf/NaN codes have to be encoded through the mantissa. +I.e. setting inf to 1101.111 (E4M3) is not a valid code.

+
+
+forward(x, exponent_bit_width, mantissa_bit_width, exponent_bias)[source]#
+

Define the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +
+
+inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)[source]#
+
+ +
+
+saturating_clamp(x, max_value, min_value)[source]#
+
+ +

+
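To make the inf/NaN remark above concrete, the sketch below computes the largest representable E4M3 magnitude and clamps a tensor to it. The helper is hypothetical (not Brevitas API) and assumes an IEEE-like layout where NaN is encoded through the mantissa, so the all-ones exponent with all-ones mantissa is reserved and the largest finite code drops the lowest mantissa bit.

import torch

def minifloat_max_sketch(exp_bits=4, man_bits=3, bias=7):
    # Largest finite E4M3 code: all-ones exponent with mantissa 110,
    # i.e. 1.75 * 2**(15 - 7) = 448.
    max_exponent = (2 ** exp_bits - 1) - bias
    mantissa = 1.0 + sum(2.0 ** -i for i in range(1, man_bits + 1)) - 2.0 ** -man_bits
    return mantissa * 2.0 ** max_exponent

max_val = minifloat_max_sketch()       # 448.0
x = torch.tensor([500.0, -1000.0, 3.0])
saturated = x.clamp(-max_val, max_val)  # saturating=True behaviour
# With saturating=False, out-of-range values would instead map to NaN
# (or to inf, when the format defines inf codes).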
class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)[source]#
-

Bases: Module

-

ScriptModule wrapper for clamp().

+

Bases: Module

+

ScriptModule wrapper for clamp().

Examples

>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
 >>> scalar_clamp(torch.tensor([-3.0, 3.0]))
@@ -486,7 +520,7 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -502,7 +536,7 @@

Submodules
class brevitas.core.function_wrapper.clamp.TensorClamp[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for tensor_clamp().

Examples

>>> tensor_clamp = TensorClamp()
@@ -515,7 +549,7 @@ 

Submodules
forward(x, min_val, max_val)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -535,7 +569,7 @@

Submodules
class brevitas.core.function_wrapper.misc.Identity[source]#
-

Bases: Module

+

Bases: Module

Identity ScriptModule.

Examples

>>> identity = Identity()
@@ -548,9 +582,9 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -565,7 +599,7 @@

Submodules
class brevitas.core.function_wrapper.misc.InplaceLogTwo[source]#
-

Bases: Module

+

Bases: Module

Module wrapper for log2_().

Examples

>>> inplace_log_two = InplaceLogTwo()
@@ -580,9 +614,9 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -597,8 +631,8 @@

Submodules
class brevitas.core.function_wrapper.misc.LogTwo[source]#
-

Bases: Module

-

ScriptModule wrapper for log2().

+

Bases: Module

+

ScriptModule wrapper for log2().

Examples

>>> log_two = LogTwo()
 >>> x = torch.tensor(8.0)
@@ -609,9 +643,9 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -626,7 +660,7 @@

Submodules
class brevitas.core.function_wrapper.misc.PowerOfTwo[source]#
-

Bases: Module

+

Bases: Module

ScriptModule implementation of 2.0 ** x.

Examples

>>> power_of_two = PowerOfTwo()
@@ -638,9 +672,9 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses. -:rtype: Tensor

+:rtype: Tensor

Note

Although the recipe for forward pass needs to be defined within @@ -659,12 +693,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.CeilSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for ceil_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -680,12 +714,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.DPURoundSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for dpu_round_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -701,12 +735,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.FloorSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for floor_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -722,12 +756,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for tensor_clamp_ste_().

forward(x, min_val, max_val)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -743,12 +777,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.RoundSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for round_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -764,12 +798,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for round_to_zero_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -785,12 +819,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for scalar_clamp_min_ste().

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -806,12 +840,12 @@

Submodules
class brevitas.core.function_wrapper.ops_ste.TensorClampSte[source]#
-

Bases: Module

+

Bases: Module

ScriptModule wrapper for tensor_clamp_ste().

forward(x, min_val, max_val)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -828,10 +862,30 @@

Submodules

brevitas.core.function_wrapper.shape module#

ScriptModule classes to compute the view of a tensor according to various different criteria.

+
+
+class brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView(group_size, group_dim)[source]#
+

Bases: Module

+
+
+forward(x)[source]#
+

Define the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +
+
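DynamicOverSubChannelBlockView is new in this release and undocumented here; judging from its group_size and group_dim arguments, a plausible reading is a groupwise view that splits one dimension into (num_groups, group_size) blocks so statistics can be reduced per group. The sketch below is my own guess under that assumption, not the Brevitas implementation.

import torch

def group_view_sketch(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    # Split `group_dim` into (num_groups, group_size).
    shape = list(x.shape)
    assert shape[group_dim] % group_size == 0
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    return x.view(shape)

x = torch.randn(8, 64)
grouped = group_view_sketch(x, group_size=16, group_dim=1)  # shape (8, 4, 16)
scale = grouped.abs().amax(dim=-1, keepdim=True)            # one statistic per group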
class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView(permute_dims=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.

Examples

@@ -844,7 +898,7 @@

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -860,7 +914,7 @@

Submodules
class brevitas.core.function_wrapper.shape.OverBatchOverTensorView(permute_dims=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule to compute the over_batch_over_tensor() view of an input tensor.

Examples

@@ -873,7 +927,7 @@

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -889,7 +943,7 @@

Submodules
class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule to compute the over_output_channels() view of an input tensor.

Examples

@@ -902,7 +956,7 @@

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -918,7 +972,7 @@

Submodules
class brevitas.core.function_wrapper.shape.OverOutputFeaturesView(permute_dims=None)[source]#
-

Bases: Module

+

Bases: Module

ScriptModule to compute the over_output_features() view of an input tensor.

Examples

@@ -931,7 +985,27 @@

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+

+ +

+ +
+
+class brevitas.core.function_wrapper.shape.OverSubChannelBlockView(expanded_groupwise_shape, group_size, group_dim)[source]#
+

Bases: Module

+
+
+forward(x)[source]#
+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -947,7 +1021,7 @@

Submodules
class brevitas.core.function_wrapper.shape.OverTensorView[source]#
-

Bases: Module

+

Bases: Module

ScriptModule to compute the over_tensor() view of an input tensor.

Examples

>>> view_module = OverTensorView()
@@ -959,7 +1033,7 @@ 

Submodules
forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -975,11 +1049,11 @@

Submodules
class brevitas.core.function_wrapper.shape.PermuteDims(permute_dims)[source]#
-

Bases: Module

+

Bases: Module

forward(x)[source]#
-

Defines the computation performed at every call.

+

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

@@ -995,9 +1069,15 @@

Submodules
class brevitas.core.function_wrapper.shape.StatsInputViewShapeImpl[source]#
-

Bases: object

+

Bases: object

Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. All adhere to the same interface.

+
+
+DYNAMIC_OVER_SUBCHANNEL_BLOCK#
+

alias of DynamicOverSubChannelBlockView

+
+
OVER_BATCH_OVER_OUTPUT_CHANNELS#
@@ -1022,6 +1102,12 @@

SubmodulesOverOutputFeaturesView

+
+
+OVER_SUBCHANNEL_BLOCK#
+

alias of OverSubChannelBlockView

+
+
OVER_TENSOR#
@@ -1085,6 +1171,12 @@

SubmodulesClampMin.forward() +
  • FloatClamp +
  • ScalarClamp @@ -1150,6 +1242,10 @@

    Submodulesbrevitas.core.function_wrapper.shape module

  • +
  • OverSubChannelBlockView +
  • OverTensorView @@ -1175,10 +1275,12 @@

    SubmodulesStatsInputViewShapeImpl

  • diff --git a/docs/api_reference/brevitas.core.html b/docs/api_reference/brevitas.core.html index a95aa4cd4..7302ead4b 100644 --- a/docs/api_reference/brevitas.core.html +++ b/docs/api_reference/brevitas.core.html @@ -9,7 +9,7 @@ - brevitas.core package — Brevitas 0.10.2 documentation + brevitas.core package — Brevitas 0.11.0 documentation @@ -126,8 +126,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +

    @@ -478,6 +478,12 @@

    SubpackagesClampMin.forward() +
  • FloatClamp +
  • ScalarClamp @@ -543,6 +549,10 @@

    Subpackagesbrevitas.core.function_wrapper.shape module

  • +
  • OverSubChannelBlockView +
  • OverTensorView @@ -568,10 +582,12 @@

    SubpackagesStatsInputViewShapeImpl

  • @@ -668,6 +684,10 @@

    Subpackagesbrevitas.core.scaling.runtime module +
  • HalfQuadraticOptimizerScale +
  • +
  • HalfQuadraticOptimizerZeroPoint +
  • KLMinimizerThreshold
  • +
  • masked_median()
  • +
  • shrink_lp_op()
  • brevitas.core.stats.stats_wrapper module
  • @@ -790,13 +824,23 @@

    Submodules
    class brevitas.core.restrict_val.FloatRestrictValue[source]#
    -

    Bases: Module

    +

    Bases: Module

    +
    +
    +combine_scale_threshold(x, threshold)[source]#
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -811,7 +855,7 @@

    Submodulesrestrict_init_float(x)[source]#
    Return type:
    -

    float

    +

    float

    @@ -831,7 +875,7 @@

    Submodulesrestrict_init_tensor(x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -841,11 +885,21 @@
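combine_scale_threshold() is new on all the restrict_val classes in this release. Its behavior is not documented here; a plausible reading, sketched under assumption, is that the combination of a scale with a threshold depends on the domain the restrictor works in: a division in the linear domain becomes a subtraction in the log2 domain.

import torch

def combine_linear(scale: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # Linear-domain restrictors (e.g. FloatRestrictValue): plain division.
    return scale / threshold

def combine_log2(log2_scale: torch.Tensor, log2_threshold: torch.Tensor) -> torch.Tensor:
    # Log-domain restrictors (e.g. PowerOfTwoRestrictValue) carry values as
    # exponents, so the same combination is a subtraction.
    return log2_scale - log2_threshold

s, t = torch.tensor(0.5), torch.tensor(2.0)
assert torch.isclose(combine_linear(s, t),
                     2.0 ** combine_log2(torch.log2(s), torch.log2(t)))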

    Submodules
    class brevitas.core.restrict_val.IntRestrictValue(restrict_value_float_to_int_impl=RoundSte())[source]#
    -

    Bases: Module

    +

    Bases: Module

    +
    +
    +combine_scale_threshold(x, threshold)[source]#
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -881,11 +935,21 @@

    Submodules
    class brevitas.core.restrict_val.LogFloatRestrictValue[source]#
    -

    Bases: Module

    +

    Bases: Module

    +
    +
    +combine_scale_threshold(x, threshold)[source]#
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -921,11 +985,21 @@

    Submodules
    class brevitas.core.restrict_val.PowerOfTwoRestrictValue(restrict_value_float_to_int_impl=RoundSte())[source]#
    -

    Bases: Module

    +

    Bases: Module

    +
    +
    +combine_scale_threshold(x, threshold)[source]#
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -964,11 +1038,11 @@

    Submodules
    class brevitas.core.utils.ParameterWrapper(value)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward()[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -984,11 +1058,11 @@

    Submodules
    class brevitas.core.utils.SingleArgStatelessBuffer(value)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(placeholder)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -1004,13 +1078,13 @@

    Submodules
    class brevitas.core.utils.SliceTensor[source]#
    -

    Bases: Module

    +

    Bases: Module

    eager_forward(x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1018,9 +1092,9 @@

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1035,11 +1109,11 @@

    Submodules
    class brevitas.core.utils.StatelessBuffer(value)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward()[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -1053,7 +1127,7 @@

    Submodules
    state_dict(destination=None, prefix='', keep_vars=False)[source]#
    -

    Returns a dictionary containing references to the whole state of the module.

    +

    Return a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

    @@ -1077,13 +1151,13 @@

    Submodules
    Parameters:
      -
    • destination (dict, optional) – If provided, the state of module will +

    • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

    • -
    • prefix (str, optional) – a prefix added to parameter and buffer +

    • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

    • -
    • keep_vars (bool, optional) – by default the Tensor s +

    • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

    • @@ -1093,7 +1167,7 @@

      Submodules

      a dictionary containing a whole state of the module

    Return type:
    -

    dict

    +

    dict

    Example:

    @@ -1111,7 +1185,7 @@

    Submodulesbrevitas.core.utils.inplace_momentum_update(tensor, update, momentum, counter, new_counter)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1121,7 +1195,7 @@

    Submodulesbrevitas.core.utils.inplace_tensor_add(tensor, value)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1131,7 +1205,7 @@

    Submodulesbrevitas.core.utils.inplace_tensor_mul(tensor, value)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1142,13 +1216,13 @@

    Submodules
    class brevitas.core.zero_point.ParameterFromRuntimeZeroPoint(collect_stats_steps, int_quant, quantize_zero_point, zero_point_stats_impl, zero_point_shape, zero_point_stats_input_view_shape_impl, zero_point_stats_momentum=0.1, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1161,7 +1235,7 @@

    Submodules
    state_dict(destination=None, prefix='', keep_vars=False)[source]#
    -

    Returns a dictionary containing references to the whole state of the module.

    +

    Return a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

    @@ -1185,13 +1259,13 @@

    Submodules
    Parameters:
      -
    • destination (dict, optional) – If provided, the state of module will +

    • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

    • -
    • prefix (str, optional) – a prefix added to parameter and buffer +

    • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

    • -
    • keep_vars (bool, optional) – by default the Tensor s +

    • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

    • @@ -1201,7 +1275,7 @@

      Submodules

      a dictionary containing a whole state of the module

    Return type:
    -

    dict

    +

    dict

    Example:

    @@ -1217,7 +1291,7 @@

    Submodulestraining_forward(x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1227,15 +1301,15 @@

    Submodules
    class brevitas.core.zero_point.ParameterFromStatsFromParameterZeroPoint(int_quant, quantize_zero_point, zero_point_stats_input_view_shape_impl, zero_point_stats_input_concat_dim, zero_point_stats_impl, zero_point_shape, tracked_parameter_list, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule implementation of a learned scale factor initialized from statistics of a parameter, e.g. weights MSE or AbsMax.

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1248,7 +1322,7 @@

    Submodules
    state_dict(destination=None, prefix='', keep_vars=False)[source]#
    -

    Returns a dictionary containing references to the whole state of the module.

    +

    Return a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

    @@ -1272,13 +1346,13 @@

    Submodules
    Parameters:
      -
    • destination (dict, optional) – If provided, the state of module will +

    • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

    • -
    • prefix (str, optional) – a prefix added to parameter and buffer +

    • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

    • -
    • keep_vars (bool, optional) – by default the Tensor s +

    • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

    • @@ -1288,7 +1362,7 @@

      Submodules

      a dictionary containing a whole state of the module

    Return type:
    -

    dict

    +

    dict

    Example:

    @@ -1304,13 +1378,13 @@

    Submodules
    class brevitas.core.zero_point.ParameterZeroPoint(zero_point_init, int_quant, quantize_zero_point, zero_point_shape=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1325,15 +1399,15 @@

    Submodules
    class brevitas.core.zero_point.PreZeroCenterZeroPoint(stats_reduce_dim, pre_zero_point_stats_input_view_shape_impl, pre_zero_point_shape=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    Experimental ScriptModule implementation of a pre-scaling zero-point that zero-centers the incoming tensors. This is intended to be used with DecoupledIntQuant.

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1348,7 +1422,7 @@

    Submodulesget_zero_center(x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -1358,13 +1432,13 @@
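One way to read "zero-centers the incoming tensors" above is sketched below (my interpretation, not the Brevitas source): the pre-scaling zero-point is the negated per-channel mean, so adding it shifts each channel to zero mean before quantization.

import torch

x = torch.randn(4, 16)
# Zero-centering: a zero-point equal to the negated per-channel mean.
zero_point = -x.mean(dim=1, keepdim=True)
centered = x + zero_point
assert torch.allclose(centered.mean(dim=1), torch.zeros(4), atol=1e-6)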

    Submodules
    class brevitas.core.zero_point.StatsFromParameterZeroPoint(int_quant, quantize_zero_point, zero_point_stats_input_view_shape_impl, zero_point_stats_input_concat_dim, zero_point_stats_impl, zero_point_shape, tracked_parameter_list)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1379,13 +1453,13 @@

    Submodules
    class brevitas.core.zero_point.ZeroZeroPoint(dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale, bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -1450,6 +1524,7 @@

    SubmodulesSubmodules
  • brevitas.core.restrict_val module
  • @@ -446,15 +446,15 @@

    Submodules

    brevitas.core.quant.binary module#

    -class brevitas.core.quant.binary.BinaryQuant(scaling_impl, quant_delay_steps=0)[source]#
    -

    Bases: Module

    +class brevitas.core.quant.binary.BinaryQuant(scaling_impl, signed=True, quant_delay_steps=0)[source]# +

    Bases: Module

    ScriptModule that implements scaled uniform binary quantization of an input tensor. Quantization is performed with binary_sign_ste().
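In plain terms, the forward computation is y = sign(x) * scale (an illustrative sketch; the real module uses binary_sign_ste(), which behaves like sign() on nonzero inputs but passes gradients straight through, and returns scale, zero-point and bit-width alongside the output):

import torch

def binary_quant_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * scale

x = torch.tensor([-1.5, 0.2, 3.0])
binary_quant_sketch(x, torch.tensor(0.1))  # tensor([-0.1000, 0.1000, 0.1000])

ClampedBinaryQuant, documented next, additionally clamps the input to (-scale, scale) first, which zeroes gradients for inputs outside that range on the backward pass.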

    Parameters:
    • scaling_impl (Module) – Module that returns a scale factor.

    • -
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    • +
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    Returns:
    @@ -490,9 +490,9 @@

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -507,7 +507,7 @@

    Submodules
    class brevitas.core.quant.binary.ClampedBinaryQuant(scaling_impl, tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule that implements scaled uniform binary quantization of an input tensor. Before going through quantization, the input tensor is clamped between (- scale, scale), which on the backward pass zeroes gradients corresponding to inputs outside that range. @@ -517,7 +517,7 @@

    Submodules

    Returns:
    @@ -559,9 +559,9 @@

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -579,13 +579,13 @@

    Submodules
    class brevitas.core.quant.delay.DelayWrapper(quant_delay_steps)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, y)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -603,13 +603,13 @@

    Submodules
    class brevitas.core.quant.int.DecoupledRescalingIntQuant(decoupled_int_quant, pre_scaling_impl, scaling_impl, int_scaling_impl, pre_zero_point_impl, zero_point_impl, bit_width_impl)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -628,9 +628,9 @@

    Submodules
    forward(x, input_bit_width, input_is_signed)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -645,13 +645,13 @@

    Submodules
    class brevitas.core.quant.int.PrescaledRestrictIntQuant(int_quant, bit_width_impl)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -666,7 +666,7 @@

    Submodules
    class brevitas.core.quant.int.PrescaledRestrictIntQuantWithInputBitWidth(int_quant, bit_width_impl)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule that wraps around an integer quantization implementation like IntQuant. Zero-point is set to zero, scale is taken as input, bit-width is computed from an input bit-width.

    @@ -715,9 +715,9 @@

    Submodules
    forward(x, scale, input_bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -732,7 +732,7 @@

    Submodules
    class brevitas.core.quant.int.RescalingIntQuant(int_quant, scaling_impl, int_scaling_impl, zero_point_impl, bit_width_impl)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule that wraps around an integer quantization implementation like IntQuant. Scale, zero-point and bit-width are returned from their respective implementations and passed on to the integer quantization implementation.

    @@ -794,9 +794,9 @@

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -811,13 +811,13 @@

    Submodules
    class brevitas.core.quant.int.TruncIntQuant(float_to_int_impl, bit_width_impl, quant_delay_steps=0)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x, scale, zero_point, input_bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within @@ -834,19 +834,19 @@

    Submodules

    brevitas.core.quant.int_base module#

    -class brevitas.core.quant.int_base.DecoupledIntQuant(narrow_range, signed, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#
    -

    Bases: Module

    +class brevitas.core.quant.int_base.DecoupledIntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]# +

    Bases: Module

ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input pre-scale, scale, pre-zero-point, zero-point and bit-width.

    Parameters:
      -
    • narrow_range (bool) – Flag that determines whether restrict quantization to a narrow range or not.

    • -
    • signed (bool) – Flag that determines whether to quantize to a signed range or not.

    • +
    • narrow_range (bool) – Flag that determines whether restrict quantization to a narrow range or not.

    • +
    • signed (bool) – Flag that determines whether to quantize to a signed range or not.

    • float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()

    • tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()

    • -
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    • +
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    Returns:
    @@ -874,9 +874,9 @@

    Submodules
    forward(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -901,7 +901,7 @@

    Submodulesto_int(pre_scale, pre_zero_point, bit_width, x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -910,19 +910,19 @@

    Submodules
    -class brevitas.core.quant.int_base.IntQuant(narrow_range, signed, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#
    -

    Bases: Module

    +class brevitas.core.quant.int_base.IntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]# +

    Bases: Module

ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input scale, zero-point and bit-width.
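The affine scheme described above can be sketched as follows (illustrative; in the real module rounding and clamping are pluggable via float_to_int_impl and tensor_clamp_impl, and rounding uses a straight-through estimator):

import torch

def int_quant_sketch(x, scale, zero_point, bit_width, signed=True, narrow_range=False):
    if signed:
        max_int = 2. ** (bit_width - 1) - 1
        min_int = -2. ** (bit_width - 1) + (1 if narrow_range else 0)
    else:
        max_int = 2. ** bit_width - (2 if narrow_range else 1)
        min_int = 0.
    # Scale to integers, shift by the zero-point, round, clamp, map back.
    int_val = torch.clamp(torch.round(x / scale + zero_point), min_int, max_int)
    return (int_val - zero_point) * scale  # dequantized output

x = torch.tensor([-1.0, 0.3, 2.0])
int_quant_sketch(x, scale=torch.tensor(0.5), zero_point=torch.tensor(0.), bit_width=4)
# tensor([-1.0000, 0.5000, 2.0000])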

    Parameters:
      -
    • narrow_range (bool) – Flag that determines whether restrict quantization to a narrow range or not.

    • -
    • signed (bool) – Flag that determines whether to quantize to a signed range or not.

    • +
    • narrow_range (bool) – Flag that determines whether restrict quantization to a narrow range or not.

    • +
    • signed (bool) – Flag that determines whether to quantize to a signed range or not.

    • float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()

    • tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()

    • -
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    • +
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    Returns:
    @@ -953,9 +953,9 @@

    Submodules
    forward(scale, zero_point, bit_width, x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -980,7 +980,7 @@

    Submodulesto_int(scale, zero_point, bit_width, x)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -993,15 +993,15 @@


    Submodules
    class brevitas.core.quant.ternary.TernaryQuant(scaling_impl, threshold, quant_delay_steps=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule that implements scaled uniform ternary quantization of an input tensor. Quantization is performed with ternary_sign_ste().
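Concretely, values whose magnitude falls below threshold * scale collapse to zero and the rest map to +/- scale (an illustrative sketch; the real module uses ternary_sign_ste() for the backward pass and also returns scale, zero-point and bit-width):

import torch

def ternary_quant_sketch(x, scale, threshold):
    out = torch.sign(x) * (x.abs() > threshold * scale)
    return out * scale

x = torch.tensor([-2.0, 0.1, 0.9])
ternary_quant_sketch(x, scale=torch.tensor(1.0), threshold=0.5)
# tensor([-1., 0., 1.])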

    Parameters:
    • scaling_impl (Module) – Module that returns a scale factor.

    • -
    • threshold (float) – Ternarization threshold w.r.t. to the scale factor.

    • -
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    • +
    • threshold (float) – Ternarization threshold w.r.t. to the scale factor.

    • +
    • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

    Returns:
    @@ -1041,9 +1041,9 @@

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    +:rtype: Tuple[Tensor, Tensor, Tensor, Tensor]

    Note

    Although the recipe for forward pass needs to be defined within diff --git a/docs/api_reference/brevitas.core.scaling.html b/docs/api_reference/brevitas.core.scaling.html index 642ea6328..cf7a68d21 100644 --- a/docs/api_reference/brevitas.core.scaling.html +++ b/docs/api_reference/brevitas.core.scaling.html @@ -9,7 +9,7 @@ - brevitas.core.scaling package — Brevitas 0.10.2 documentation + brevitas.core.scaling package — Brevitas 0.11.0 documentation @@ -126,8 +126,8 @@ - Brevitas 0.10.2 documentation - Home - + Brevitas 0.11.0 documentation - Home +

    @@ -447,13 +447,13 @@

    Submodules
    class brevitas.core.scaling.int_scaling.IntScaling(signed, narrow_range)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -468,13 +468,13 @@

    Submodules
    class brevitas.core.scaling.int_scaling.PowerOfTwoIntScaling(signed)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(bit_width)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -489,15 +489,37 @@
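These int-scaling modules map a bit-width to the size of the integer range that a float threshold is divided by to obtain a scale factor. A plausible reading, sketched under assumption (names mine; narrow_range would trim one code from the range):

import torch

def int_scale_sketch(bit_width, signed=True, power_of_two=False):
    # e.g. 127 for signed 8-bit, 255 for unsigned 8-bit; the power-of-two
    # variant keeps the range itself a power of two.
    if power_of_two:
        return 2. ** (bit_width - 1) if signed else 2. ** bit_width
    return 2. ** (bit_width - 1) - 1 if signed else 2. ** bit_width - 1

threshold = torch.tensor(1.0)
scale = threshold / int_scale_sketch(torch.tensor(8.))  # ~1/127 for signed 8-bit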

    Submodules

    brevitas.core.scaling.runtime module#

    +
    +
    +class brevitas.core.scaling.runtime.RuntimeDynamicGroupStatsScaling(group_size, group_dim, input_view_impl, scaling_stats_impl, scaling_min_val, restrict_scaling_impl=FloatRestrictValue())[source]#
    +

    Bases: Module

    +
    +
    +forward(stats_input, threshold=None)[source]#
    +

    Define the computation performed at every call.

    +

    Should be overridden by all subclasses. +:rtype: Tensor

    +
    +

    Note

    +

    Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

    +
    +
    + +
    +
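RuntimeDynamicGroupStatsScaling is new in this release; as the name suggests, the scale is recomputed from the current input at every call ("dynamic") rather than learned or averaged. A plausible sketch under that assumption, reusing the groupwise view idea shown earlier for the shape modules:

import torch

def dynamic_group_scale_sketch(x, group_size, group_dim, min_val=1e-8):
    # View the tensor in groups, take a per-group absmax at runtime, floor it.
    shape = list(x.shape)
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    grouped = x.view(shape)
    return grouped.abs().amax(dim=group_dim + 1, keepdim=True).clamp_min(min_val)

scales = dynamic_group_scale_sketch(torch.randn(8, 64), group_size=16, group_dim=1)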
    -class brevitas.core.scaling.runtime.RuntimeStatsScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, restrict_scaling_impl, scaling_shape, affine_rescaling=False, affine_shift_scale=False, scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.runtime.RuntimeStatsScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_shape, affine_rescaling=False, affine_shift_scale=False, restrict_scaling_impl=FloatRestrictValue(), scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    -forward(x)[source]#
    -

    Defines the computation performed at every call.

    -

    Should be overridden by all subclasses.

    +forward(x, threshold=None)[source]# +

    Define the computation performed at every call.

    +

    Should be overridden by all subclasses. +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -511,14 +533,14 @@

    Submodules
    -class brevitas.core.scaling.runtime.StatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, restrict_scaling_impl, scaling_shape, affine_rescaling=False, affine_shift_scale=False, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.runtime.StatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, restrict_scaling_impl=FloatRestrictValue(), affine_rescaling=False, affine_shift_scale=False, scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    -forward(ignored)[source]#
    -

    Defines the computation performed at every call.

    +forward(ignored, threshold=None)[source]# +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -535,15 +557,15 @@

    Submodules

    brevitas.core.scaling.standalone module#

    -class brevitas.core.scaling.standalone.ConstScaling(scaling_init, restrict_scaling_impl=None, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.standalone.ConstScaling(scaling_init, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    ScriptModule implementation of a constant scale factor.

    Parameters:
      -
    • scaling_init (Union[float, Tensor]) – value to use as constant scale factor.

    • +
    • scaling_init (Union[float, Tensor]) – value to use as constant scale factor.

    • restrict_scaling_impl (Module) – restrict scaling_init according to some criteria. Default: None

    • -
    • scaling_min_val (float) – force a lower-bound on scaling_init. Default: None

    • +
    • scaling_min_val (float) – force a lower-bound on scaling_init. Default: None

    Returns:
    @@ -576,10 +598,10 @@

    Submodules
    -forward(placeholder)[source]#
    -

    Defines the computation performed at every call.

    +forward(placeholder, threshold=None)[source]# +

    Define the computation performed at every call.

    Should be overridden by all subclasses. -:rtype: Tensor

    +:rtype: Tensor

    Note

    Although the recipe for forward pass needs to be defined within @@ -593,8 +615,8 @@

    Submodules
    -class brevitas.core.scaling.standalone.ParameterFromRuntimeStatsScaling(collect_stats_steps, scaling_stats_impl, scaling_stats_input_view_shape_impl=OverBatchOverTensorView(   (permute_impl): Identity() ), scaling_shape=(), restrict_scaling_impl=None, scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.standalone.ParameterFromRuntimeStatsScaling(collect_stats_steps, scaling_stats_impl, scaling_stats_input_view_shape_impl=OverBatchOverTensorView(   (permute_impl): Identity() ), scaling_shape=(), restrict_scaling_impl=FloatRestrictValue(), scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    ScriptModule implementation of a learned scale factor initialized from runtime statistics. The implementation works in two phases. During the first phase, statistics are collected in the same fashion as batchnorm, meaning that while the module is in training mode a set of per-batch @@ -605,15 +627,15 @@

    Submodules
    Parameters:
      -
    • collect_stats_steps (int) – Number of calls to the forward method in training mode to collect statistics for.

    • +
    • collect_stats_steps (int) – Number of calls to the forward method in training mode to collect statistics for.

    • scaling_stats_impl (Module) – Implementation of the statistics computed during the collection phase.

    • scaling_stats_input_view_shape_impl (Module) – Implementation of the view applied to the runtime input during the statistics collection phase. Default: OverBatchOverTensorView().

    • -
    • scaling_shape (Tuple[int, ...]) – shape of the torch.nn.Parameter used in the second phase. Default: SCALAR_SHAPE.

    • +
    • scaling_shape (Tuple[int, ...]) – shape of the torch.nn.Parameter used in the second phase. Default: SCALAR_SHAPE.

• restrict_scaling_impl (Module) – restrict the learned scale factor according to some criteria. Default: FloatRestrictValue()

    • -
• scaling_stats_momentum (Optional[float]) – Momentum for the statistics moving average. Default: DEFAULT_MOMENTUM.

    • -
    • scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None.

    • +
• scaling_stats_momentum (Optional[float]) – Momentum for the statistics moving average. Default: DEFAULT_MOMENTUM.

    • +
    • scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None.

    Returns:
    @@ -623,7 +645,7 @@

    Submodules

    Tensor

    Raises:
    -

    RuntimeError – if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None

    +

    RuntimeError – if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None

    Examples
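The rendered doctest is elided by this hunk; below is a hedged stand-in illustrating the two-phase behaviour described above (import paths follow the documented module; outputs omitted since they depend on the random input):

>>> import torch
>>> from brevitas.core.scaling.standalone import ParameterFromRuntimeStatsScaling
>>> from brevitas.core.stats.stats_op import AbsMax
>>> scaling_impl = ParameterFromRuntimeStatsScaling(collect_stats_steps=1, scaling_stats_impl=AbsMax())
>>> _ = scaling_impl.train()                 # phase 1: collect per-batch statistics
>>> _ = scaling_impl(torch.randn(8, 16))
>>> _ = scaling_impl.eval()                  # phase 2: stats now back a learned torch.nn.Parameter
>>> scale = scaling_impl(torch.randn(8, 16))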

    @@ -649,10 +671,10 @@

    Submodules
    -forward(stats_input)[source]#
    -

    Defines the computation performed at every call.

    +forward(stats_input, threshold=None)[source]# +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -665,7 +687,7 @@

    Submodules
    state_dict(destination=None, prefix='', keep_vars=False)[source]#
    -

    Returns a dictionary containing references to the whole state of the module.

    +

    Return a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

    @@ -689,13 +711,13 @@

    Submodules
    Parameters:
      -
    • destination (dict, optional) – If provided, the state of module will +

    • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

    • -
    • prefix (str, optional) – a prefix added to parameter and buffer +

    • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

    • -
    • keep_vars (bool, optional) – by default the Tensor s +

    • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

    • @@ -705,7 +727,7 @@

      Submodules

      a dictionary containing a whole state of the module

    Return type:
    -

    dict

    +

    dict

    Example:

    @@ -718,10 +740,10 @@

    Submodules
    -training_forward(stats_input)[source]#
    +training_forward(stats_input, threshold)[source]#
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -730,16 +752,16 @@

    Submodules
    -class brevitas.core.scaling.standalone.ParameterFromStatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, restrict_scaling_impl, scaling_shape, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.standalone.ParameterFromStatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    ScriptModule implementation of a learned scale factor initialized from statistics of a parameter, e.g. weights MSE or AbsMax.
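An illustrative wiring of this module (hedged sketch: the stats pipeline shown is one plausible configuration, not one prescribed by the docs):

>>> import torch
>>> from brevitas.core.scaling.standalone import ParameterFromStatsFromParameterScaling
>>> from brevitas.core.stats.stats_op import AbsMax
>>> from brevitas.core.function_wrapper.shape import OverTensorView
>>> weight = torch.nn.Parameter(torch.randn(16, 8))
>>> scaling_impl = ParameterFromStatsFromParameterScaling(
...     scaling_stats_impl=AbsMax(),
...     scaling_stats_input_view_shape_impl=OverTensorView(),
...     scaling_stats_input_concat_dim=0,
...     tracked_parameter_list=[weight],
...     scaling_shape=())
>>> scale = scaling_impl(torch.empty(0))  # the input is ignored; stats come from the tracked parameter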

    -forward(ignored)[source]#
    -

    Defines the computation performed at every call.

    +forward(ignored, threshold=None)[source]# +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -752,7 +774,7 @@

    Submodules
    state_dict(destination=None, prefix='', keep_vars=False)[source]#
    -

    Returns a dictionary containing references to the whole state of the module.

    +

    Return a dictionary containing references to the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

    @@ -776,13 +798,13 @@

    Submodules
    Parameters:
      -
    • destination (dict, optional) – If provided, the state of module will +

    • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

    • -
    • prefix (str, optional) – a prefix added to parameter and buffer +

    • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

    • -
    • keep_vars (bool, optional) – by default the Tensor s +

    • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

    • @@ -792,7 +814,7 @@

      Submodules

      a dictionary containing a whole state of the module

    Return type:
    -

    dict

    +

    dict

    Example:

    @@ -807,16 +829,16 @@

    Submodules
    -class brevitas.core.scaling.standalone.ParameterScaling(scaling_init, scaling_shape=None, restrict_scaling_impl=None, scaling_min_val=None, dtype=None, device=None)[source]#
    -

    Bases: Module

    +class brevitas.core.scaling.standalone.ParameterScaling(scaling_init, scaling_shape=None, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]# +

    Bases: Module

    ScriptModule implementation of a learned scale factor.

    Parameters:
      -
    • scaling_init (Union[float, Tensor]) – value to initialize the learned scale factor.

    • -
    • scaling_shape (Tuple[int, ...]) – shape to extend a scalar float or tensor scaling_init. Default: None

    • +
    • scaling_init (Union[float, Tensor]) – value to initialize the learned scale factor.

    • +
    • scaling_shape (Tuple[int, ...]) – shape to extend a scalar float or tensor scaling_init. Default: None

• restrict_scaling_impl (Module) – restrict the learned scale factor according to some criteria. Default: FloatRestrictValue()

    • -
    • scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None

    • +
    • scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None

    Returns:
    @@ -826,7 +848,7 @@

    Submodules

    Tensor

    Raises:
    -

    RuntimeError – if scaling_init is a non-scalar tensor and scaling_shape is != scaling_init.shape.

    +

    RuntimeError – if scaling_init is a non-scalar tensor and scaling_shape is != scaling_init.shape.

    Examples
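The rendered doctest is elided by this hunk; a hedged stand-in (treat the printed output, including the grad_fn name, as illustrative):

>>> import torch
>>> from brevitas.core.scaling.standalone import ParameterScaling
>>> scaling_impl = ParameterScaling(6.0)
>>> scaling_impl(torch.empty(1))
tensor(6., grad_fn=<AbsBinarySignGradFnBackward>)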

    @@ -857,10 +879,10 @@

    Submodules
    -forward(placeholder)[source]#
    -

    Defines the computation performed at every call.

    +forward(placeholder, threshold=None)[source]# +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -934,6 +956,10 @@

Submodules

brevitas.core.scaling.runtime module

    @@ -447,11 +447,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsAve(stats_reduce_dim=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -467,11 +467,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsMax(stats_reduce_dim=None, keepdim=False)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -487,11 +487,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsMaxAve(stats_reduce_dim)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -507,11 +507,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsMaxL2(stats_reduce_dim)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -527,11 +527,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsMinMax(stats_reduce_dim=None, keepdim=False, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -547,11 +547,11 @@

    Submodules
    class brevitas.core.stats.stats_op.AbsPercentile(high_percentile_q, stats_reduce_dim, percentile_q=None, keepdim=False)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -564,16 +564,76 @@

    Submodules +
    +class brevitas.core.stats.stats_op.HalfQuadraticOptimizerScale(proxy_module, hqo_init_op_scale, keepdim, inner_stats_input_view_shape_impl, scaling_min_val=None, stats_reduce_dim=None, int_scaling_impl=None, bit_width_impl=None, hqo_beta_scale=100000.0, hqo_kappa_scale=1.01, hqo_lp_norm_scale=0.7, hqo_iters_scale=1000)[source]#
    +

    Bases: Module

    +
    +
    +forward(x)[source]#
    +

    Define the computation performed at every call.

    +

    Should be overridden by all subclasses.

    +
    +

    Note

    +

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    +
    +
    + +
    +
    +optimize(x)[source]#
    +
    + +
    + +
    + +

    + +
    +
    +class brevitas.core.stats.stats_op.HalfQuadraticOptimizerZeroPoint(proxy_module, keepdim, hqo_init_op_zp, inner_stats_input_view_shape_impl, stats_reduce_dim=None, hqo_beta_zp=1.0, hqo_kappa_zp=1.01, hqo_lp_norm_zp=0.5, hqo_iters_zp=1000)[source]#
    +

    Bases: Module

    +
    +
    +forward(x)[source]#
    +

    Define the computation performed at every call.

    +

    Should be overridden by all subclasses.

    +
    +

    Note

    +

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    +
    +
    + +
    +
    +optimize(x)[source]#
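Both half-quadratic optimizer modules follow the HQQ-style recipe: alternate a closed-form shrinkage step (shrink_lp_op, documented below) with a re-fit of the quantization parameter, annealing beta by kappa each iteration. A rough, hedged sketch of the scale variant (the helper below is hypothetical and simplified, not the Brevitas internals):

import torch

from brevitas.core.stats.stats_op import shrink_lp_op

def hqo_scale_sketch(x, scale, qmin, qmax, beta=1e5, kappa=1.01, lp_norm=0.7, iters=1000):
    # Alternating minimization of an Lp reconstruction error, HQQ-style.
    for _ in range(iters):
        q = torch.clamp(torch.round(x / scale), qmin, qmax)  # quantize at the current scale
        w_e = shrink_lp_op(x - scale * q, beta, lp_norm)     # proximal shrinkage of the error
        # Closed-form least-squares re-fit of the scale against the compensated target.
        scale = torch.sum(q * (x - w_e)) / (torch.sum(q * q) + 1e-12)
        beta = beta * kappa                                  # anneal the penalty weight
    return scale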
    +
    + +
    + +
    + +
    +
    class brevitas.core.stats.stats_op.KLMinimizerThreshold(signed, bit_width_impl, num_bins=1001, smoothing_eps=0.0001)[source]#
    -

    Bases: Module

    +

    Bases: Module

    Based on: apache/incubator-mxnet

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -594,13 +654,13 @@

    Submodules
    class brevitas.core.stats.stats_op.L1Norm(stats_reduce_dim=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule implementation to collect per-channel L1 normalization stats for weight normalization-based quantization.

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -616,13 +676,13 @@

    Submodules
    class brevitas.core.stats.stats_op.L2Norm(stats_reduce_dim=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    ScriptModule implementation to collect per-channel L2 normalization stats for weight normalization-based quantization.

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -638,7 +698,7 @@

    Submodules
    class brevitas.core.stats.stats_op.MSE(proxy_module, mse_init_op, inner_stats_input_view_shape_impl, stats_reduce_dim=None, mse_search_method='fibonacci', mse_iters=20)[source]#
    -

    Bases: Module

    +

    Bases: Module

    evaluate_loss(x, candidate)[source]#
    @@ -647,7 +707,7 @@
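The MSE module searches for the quantization parameter candidate with the lowest mean squared reconstruction error; evaluate_loss computes the error of a single candidate, and the search loop (fibonacci by default) drives it. A hedged sketch with a plain grid search substituted for the fibonacci method (helper names hypothetical):

import torch

def mse_scale_search_sketch(x, quant_fn, base_scale, iters=20):
    # Try candidate scales around base_scale; keep the one minimizing the MSE.
    best_scale, best_loss = base_scale, float("inf")
    for k in range(iters):
        candidate = base_scale * (0.5 + 0.5 * k / max(iters - 1, 1))
        x_q = quant_fn(x, candidate)       # quantize-dequantize at this candidate
        loss = torch.mean((x - x_q) ** 2)  # what evaluate_loss measures per candidate
        if loss < best_loss:
            best_scale, best_loss = candidate, loss
    return best_scale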

    Submodules
    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -683,11 +743,11 @@

    Submodules
    class brevitas.core.stats.stats_op.MeanLearnedSigmaStd(sigma, stats_output_shape, stats_reduce_dim=None, std_dev_epsilon=1e-08, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -703,11 +763,11 @@

    Submodules
    class brevitas.core.stats.stats_op.MeanSigmaStd(sigma, stats_reduce_dim=None, std_dev_epsilon=1e-08, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note

    @@ -723,13 +783,13 @@

    Submodules
    class brevitas.core.stats.stats_op.NegativeMinOrZero(stats_reduce_dim=None, dtype=None, device=None, keepdim=False)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -744,13 +804,13 @@

    Submodules
    class brevitas.core.stats.stats_op.NegativePercentileOrZero(low_percentile_q, stats_reduce_dim=None, dtype=None, device=None, keepdim=False)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -765,13 +825,13 @@

    Submodules
    class brevitas.core.stats.stats_op.PercentileInterval(low_percentile_q, high_percentile_q, stats_reduce_dim=None, keepdim=False, dtype=None, device=None)[source]#
    -

    Bases: Module

    +

    Bases: Module

    forward(x)[source]#
    -

    Defines the computation performed at every call.

    +

    Define the computation performed at every call.

Should be overridden by all subclasses.
-:rtype: Tensor

+:rtype: Tensor

    Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

@@ -783,6 +843,40 @@

    Submodules +
    +brevitas.core.stats.stats_op.masked_median(x, mask, dim=None, keepdim=False)[source]#
    +

Compute the median of tensor x along dim, ignoring values where mask is False. x and mask need to be broadcastable.

    +
    +
    Parameters:
    +
      +
    • x (Tensor) – Tensor to compute median of.

    • +
• mask (BoolTensor) – Same shape as x with True where x is valid and False where x should be masked. Mask should not be all False in any column of dimension dim to avoid NaNs from zero division.

    • +
• dim (int, optional) – Dimension to take the median of. Default: None.

    • +
    +
    +
    Returns:
    +

    Same shape as x, except dimension dim reduced.

    +
    +
    Return type:
    +

    Tensor
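A small usage sketch (illustrative):

>>> import torch
>>> from brevitas.core.stats.stats_op import masked_median
>>> x = torch.tensor([3.0, 1.0, 100.0, 2.0])
>>> mask = torch.tensor([True, True, False, True])
>>> masked_median(x, mask)  # 100.0 is masked out; median of [3., 1., 2.]
tensor(2.)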

    +
    +
    +

    + +
    +
    +brevitas.core.stats.stats_op.shrink_lp_op(x, beta, lp_norm)[source]#
    +
    +
    Return type:
    +

    Tensor
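shrink_lp_op generalizes soft-thresholding to Lp norms and supplies the proximal step used by the half-quadratic optimizers above. A hedged sketch of the computation, following the HQQ formulation this helper derives from (an assumption, not a verbatim copy of the Brevitas source):

import torch
import torch.nn.functional as F

def shrink_lp_sketch(x, beta, lp_norm):
    # lp_norm == 1 recovers classic soft-thresholding; lp_norm < 1 shrinks more aggressively.
    if lp_norm == 1:
        return torch.sign(x) * F.relu(torch.abs(x) - 1.0 / beta)
    return torch.sign(x) * F.relu(torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))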

    +
    +
    +
    +

    brevitas.core.stats.stats_wrapper module#

    @@ -864,6 +958,18 @@

Submodules

• AbsPercentile.forward()
• + HalfQuadraticOptimizerScale
• + HalfQuadraticOptimizerZeroPoint
• KLMinimizerThreshold
• + masked_median()
• + shrink_lp_op()
• brevitas.core.stats.stats_wrapper module

diff --git a/docs/api_reference/brevitas.function.html b/docs/api_reference/brevitas.function.html
index 0a7b9e7ec..143a5e932 100644
--- a/docs/api_reference/brevitas.function.html
+++ b/docs/api_reference/brevitas.function.html
@@ -9,7 +9,7 @@
-    brevitas.function package — Brevitas 0.10.2 documentation
+    brevitas.function package — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

    @@ -496,7 +496,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -535,8 +535,8 @@

    Submodules
    Parameters:
      -
    • signed (bool) – Indicates whether the represented integer is signed or not.

    • -
    • narrow_range (bool) – Indicates whether to narrow the maximum unsigned value represented by 1.

    • +
    • signed (bool) – Indicates whether the represented integer is signed or not.

    • +
    • narrow_range (bool) – Indicates whether to narrow the maximum unsigned value represented by 1.

    • bit_width (Tensor) – Number of bits available for the representation.

    @@ -567,8 +567,8 @@

    Submodules
    Parameters:
      -
    • signed (bool) – Indicates whether the represented integer is signed or not.

    • -
    • narrow_range (bool) – Indicates whether to narrow the minimum value represented by 1.

    • +
    • signed (bool) – Indicates whether the represented integer is signed or not.

    • +
    • narrow_range (bool) – Indicates whether to narrow the minimum value represented by 1.

    • bit_width (Tensor) – Number of bits available for the representation.
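For intuition, these bounds follow the usual two's-complement conventions; worked values for bit_width=8 (stated as expected behaviour rather than captured output):

>>> import torch
>>> from brevitas.function.ops import max_int, min_int
>>> bw = torch.tensor(8.)
>>> max_int(signed=True, narrow_range=False, bit_width=bw)   # 2**(8-1) - 1
tensor(127.)
>>> min_int(signed=True, narrow_range=False, bit_width=bw)   # -2**(8-1)
tensor(-128.)
>>> min_int(signed=True, narrow_range=True, bit_width=bw)    # narrow range drops -128
tensor(-127.)
>>> max_int(signed=False, narrow_range=False, bit_width=bw)  # 2**8 - 1
tensor(255.)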

    @@ -621,9 +621,9 @@

    Submodules
    Parameters:
      -
    • x (Tensor) – Input on which to apply the clamp operation

    • -
    • min_val (Tensor) – Minimum values for the clamp operation.

    • -
    • max_val (Tensor) – Maximum values for the clamp operation.

    • +
    • x (Tensor) – Input on which to apply the clamp operation

    • +
    • min_val (Tensor) – Minimum values for the clamp operation.

    • +
    • max_val (Tensor) – Maximum values for the clamp operation.

    @@ -633,7 +633,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Returns:

    Input x clamped between the provided minimum and maximum tensors.
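Unlike torch.clamp with scalar bounds, tensor_clamp takes elementwise bounds. A quick illustrative sketch:

>>> import torch
>>> from brevitas.function import tensor_clamp
>>> x = torch.tensor([-2.0, 0.5, 3.0])
>>> min_val = torch.tensor([-1.0, 0.0, 0.0])
>>> max_val = torch.tensor([1.0, 1.0, 2.0])
>>> tensor_clamp(x, min_val=min_val, max_val=max_val)
tensor([-1.0000,  0.5000,  2.0000])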

    @@ -653,7 +653,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    @@ -670,11 +670,11 @@

    Submodules
    brevitas.function.ops_ste.abs_binary_sign_grad(x)[source]#
    -

Function that implements torch.abs() with a binary-sign backward, in order to have subgradient 1 in 0. Compare with torch.abs()’ subgradient of 0 in 0.

    +

Function that implements torch.abs() with a binary-sign backward, in order to have subgradient 1 in 0. Compare with torch.abs()’ subgradient of 0 in 0.

    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -701,7 +701,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -724,10 +724,10 @@

    Submodules
    brevitas.function.ops_ste.ceil_ste(x)[source]#
    -

    Function that implements torch.ceil() with a straight-through gradient estimator.

    +

    Function that implements torch.ceil() with a straight-through gradient estimator.

    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -753,7 +753,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -776,10 +776,10 @@

    Submodules
    brevitas.function.ops_ste.floor_ste(x)[source]#
    -

    Function that implements torch.floor() with a straight-through gradient estimator.

    +

    Function that implements torch.floor() with a straight-through gradient estimator.

    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -801,10 +801,10 @@

    Submodules
    brevitas.function.ops_ste.round_ste(x)[source]#
    -

    Function that implements torch.round() with a straight-through gradient estimator.

    +

    Function that implements torch.round() with a straight-through gradient estimator.

    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -830,7 +830,7 @@
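The straight-through pattern shared by these ops: a non-differentiable forward (here, rounding) paired with an identity backward, so gradients reach the pre-rounding input. A hedged demo (values follow from the shown input):

>>> import torch
>>> from brevitas.function.ops_ste import round_ste
>>> x = torch.tensor([1.2, -0.7], requires_grad=True)
>>> y = round_ste(x)
>>> y.sum().backward()
>>> x.grad  # identity gradient, as if no rounding had happened
tensor([1., 1.])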

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -859,8 +859,8 @@

    Submodules
    Parameters:
      -
    • x (Tensor) – input tensor to clamp.

    • -
    • min_val (float) – scalar value to use as lower bound for the input tensor.

    • +
    • x (Tensor) – input tensor to clamp.

    • +
    • min_val (float) – scalar value to use as lower bound for the input tensor.

    Returns:
    @@ -890,15 +890,15 @@

    Submodules
    brevitas.function.ops_ste.scalar_clamp_ste(x, min_val, max_val)[source]#
    -

    Function that implements torch.clamp() with a straight-through gradient estimator +

    Function that implements torch.clamp() with a straight-through gradient estimator for the gradient of the output w.r.t. to x, while the gradient of y w.r.t. to min_val and max_val is always None.

    Parameters:
      -
    • x (Tensor) – input tensor to clamp.

    • -
    • min_val (float) – scalar value to use as lower bound for the input tensor.

    • -
    • max_val (float) – scalar value to use as upper bound for the input tensor.

    • +
    • x (Tensor) – input tensor to clamp.

    • +
    • min_val (float) – scalar value to use as lower bound for the input tensor.

    • +
    • max_val (float) – scalar value to use as upper bound for the input tensor.

    Returns:
    @@ -933,7 +933,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -961,7 +961,7 @@

    Submodules
    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -985,10 +985,10 @@

    Submodules
    brevitas.function.ops_ste.ternary_sign_ste(x)[source]#
    -

    Function that implements torch.sign() with a straight-through gradient estimator.

    +

    Function that implements torch.sign() with a straight-through gradient estimator.

    Return type:
    -

    Tensor

    +

    Tensor

    Notes

    @@ -1043,7 +1043,7 @@

    Submodules

    x (Tensor) – Input tensor with batches at dimension 0.

    Return type:
    -

    Tuple[int, int]

    +

    Tuple[int, int]

    Returns:

    A tuple containing the 2-dim shape.
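The signature itself is elided by this hunk, but the entry matches the over-batch family of shape helpers, which reduce an input to a 2-dim view shape of (batch, -1). Hedged sketch (assuming the over_batch_over_tensor helper from this module):

>>> import torch
>>> from brevitas.function.shape import over_batch_over_tensor
>>> x = torch.randn(4, 3, 8, 8)
>>> over_batch_over_tensor(x)  # keep the batch dim, flatten everything else
(4, -1)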

    @@ -1065,7 +1065,7 @@

    Submodules
    Return type:
    -

    Tuple[int, int]

    +

    Tuple[int, int]

    Returns:

    A tuple containing the 2-dim shape.

    @@ -1107,7 +1107,7 @@

    Submodules

    x (Tensor) – Input tensor.

    Return type:
    -

    int

    +

    int

    Returns:

    The number -1 corresponding to a flat shape.

diff --git a/docs/api_reference/brevitas.ops.html b/docs/api_reference/brevitas.ops.html
index e7dbd7ca2..33ceb49f3 100644
--- a/docs/api_reference/brevitas.ops.html
+++ b/docs/api_reference/brevitas.ops.html
@@ -9,7 +9,7 @@
-    brevitas.ops package — Brevitas 0.10.2 documentation
+    brevitas.ops package — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

    @@ -445,9 +445,9 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.AbsBinarySignGradFn(*args, **kwargs)[source]#
    -

    Bases: Function

    -

Autograd function that implements torch.abs() with a binary-sign backward, in order to have subgradient 1 in 0. Compare with torch.abs()’ subgradient of 0 in 0.

    +

    Bases: Function

    +

Autograd function that implements torch.abs() with a binary-sign backward, in order to have subgradient 1 in 0. Compare with torch.abs()’ subgradient of 0 in 0.

    AbsBinarySignGradFn.apply(*args) is first aliased to abs_binary_sign_grad(*args) and then wrapped by abs_binary_sign_grad() when env BREVITAS_JIT=0. See abs_binary_sign_grad() for details on the interface and @@ -457,7 +457,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.BinarySignSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements binary_sign() with a straight-through gradient estimator.

    BinarySignSteFn.apply(*args) is first aliased to @@ -470,8 +470,8 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.CeilSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    -

    Autograd function that implements torch.ceil() with a straight-through gradient estimator.

    +

    Bases: Function

    +

    Autograd function that implements torch.ceil() with a straight-through gradient estimator.

    CeilSteFn.apply(*args) is first aliased to ceil_ste_impl(*args) and then wrapped by ceil_ste() when env BREVITAS_JIT=0. See ceil_ste() for details on the interface and @@ -481,7 +481,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.DPURoundSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements dpu_round() with a straight-through gradient estimator.

    DPURoundSteFn.apply(*args) is first aliased to dpu_round_ste_impl(*args) and then wrapped by @@ -493,8 +493,8 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.FloorSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    -

    Autograd function that implements torch.floor() with a straight-through gradient estimator.

    +

    Bases: Function

    +

    Autograd function that implements torch.floor() with a straight-through gradient estimator.

    FloorSteFn.apply(*args) is first aliased to floor_ste_impl(*args) and then wrapped by floor_ste() when env BREVITAS_JIT=0. See floor_ste() for details on the interface and @@ -504,7 +504,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.InplaceTensorClampSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements tensor_clamp_() with a straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y w.r.t. to min_val and max_val is always None.

    @@ -518,8 +518,8 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.RoundSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    -

    Autograd function that implements torch.round() with a straight-through gradient +

    Bases: Function

    +

    Autograd function that implements torch.round() with a straight-through gradient estimator.

    RoundSteFn.apply(*args) is first aliased to round_ste_impl(*args) and then wrapped by round_ste() when env BREVITAS_JIT=0. @@ -529,7 +529,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.RoundToZeroSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements round_to_zero() with a straight-through gradient estimator.

    RoundToZeroSteFn.apply(*args) is first aliased to round_to_zero_ste_impl(*args) and then wrapped by @@ -541,7 +541,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.ScalarClampMinSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements torch.clamp_min with a straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y w.r.t. to min_val is always None.

    @@ -554,7 +554,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.ScalarClampSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements torch.clamp with a straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y w.r.t. to min_val and min_val are always None.

    @@ -567,7 +567,7 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.TensorClampSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    +

    Bases: Function

    Autograd function that implements tensor_clamp() with a straight-through gradient estimator for the gradient of y w.r.t. to x, while the gradient of y w.r.t. to min_val and max_val is always None.

    @@ -580,8 +580,8 @@

    Submodules
    class brevitas.ops.autograd_ste_ops.TernarySignSteFn(*args, **kwargs)[source]#
    -

    Bases: Function

    -

    Autograd function that implements torch.sign() with a straight-through gradient estimator.

    +

    Bases: Function

    +

    Autograd function that implements torch.sign() with a straight-through gradient estimator.

    TernarySignSteFn.apply(*args) is first aliased to ternary_sign_ste_impl(*args) and then wrapped by ternary_sign_ste() when env BREVITAS_JIT=0. See ternary_sign_ste() for details on the interface and @@ -590,67 +590,67 @@

    Submodules
    -brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl()#
    +brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl(*args, **kwargs)#

    Alias for AbsBinarySignGradFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.binary_sign_ste_impl()#
    +brevitas.ops.autograd_ste_ops.binary_sign_ste_impl(*args, **kwargs)#

    Alias for BinarySignSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.ceil_ste_impl()#
    +brevitas.ops.autograd_ste_ops.ceil_ste_impl(*args, **kwargs)#

    Alias for CeilSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.dpu_round_ste_impl()#
    +brevitas.ops.autograd_ste_ops.dpu_round_ste_impl(*args, **kwargs)#

    Alias for DPURoundSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.floor_ste_impl()#
    +brevitas.ops.autograd_ste_ops.floor_ste_impl(*args, **kwargs)#

    Alias for FloorSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.round_ste_impl()#
    +brevitas.ops.autograd_ste_ops.round_ste_impl(*args, **kwargs)#

    Alias for RoundSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl()#
    +brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl(*args, **kwargs)#

    Alias for RoundToZeroSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl()#
    +brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl(*args, **kwargs)#

    Alias for ScalarClampMinSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl()#
    +brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl(*args, **kwargs)#

    Alias for ScalarClampSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl()#
    +brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl(*args, **kwargs)#

    Alias for TensorClampSteFn.apply(*args)

    -brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl()#
    +brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl(*args, **kwargs)#

    Alias for TernarySignSteFn.apply(*args)
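All of these *_impl symbols follow one idiom: define the torch.autograd.Function once, alias its .apply, then wrap the alias when env BREVITAS_JIT=0. A hedged, simplified sketch of that idiom (not the verbatim Brevitas source):

import torch

class RoundSteFnSketch(torch.autograd.Function):
    # Round on the forward pass, identity on the backward pass (straight-through).

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # pass the upstream gradient through unchanged

round_ste_impl_sketch = RoundSteFnSketch.apply  # alias, mirroring e.g. round_ste_impl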

diff --git a/docs/api_reference/index.html b/docs/api_reference/index.html
index d7a42f0f1..d4b7ecd58 100644
--- a/docs/api_reference/index.html
+++ b/docs/api_reference/index.html
@@ -9,7 +9,7 @@
-    API reference — Brevitas 0.10.2 documentation
+    API reference — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

diff --git a/docs/architecture.html b/docs/architecture.html
index 452d6b035..ab4853118 100644
--- a/docs/architecture.html
+++ b/docs/architecture.html
@@ -9,7 +9,7 @@
-    Architecture — Brevitas 0.10.2 documentation
+    Architecture — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

diff --git a/docs/faq.html b/docs/faq.html
index afd6d724c..0f2006b35 100644
--- a/docs/faq.html
+++ b/docs/faq.html
@@ -9,7 +9,7 @@
-    F.A.Q. — Brevitas 0.10.2 documentation
+    F.A.Q. — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

diff --git a/docs/genindex.html b/docs/genindex.html
index ac066b5e1..9e835720a 100644
--- a/docs/genindex.html
+++ b/docs/genindex.html
@@ -9,7 +9,7 @@
-    Index — Brevitas 0.10.2 documentation
+    Index — Brevitas 0.11.0 documentation
@@ -126,8 +126,8 @@
-    Brevitas 0.10.2 documentation - Home
+    Brevitas 0.11.0 documentation - Home

    @@ -389,6 +389,7 @@

    Index

| E | F | G + | H | I | K | L

@@ -671,14 +672,24 @@

    C

  • CeilSte (class in brevitas.core.function_wrapper.ops_ste)

    D

  • DelayWrapper (class in brevitas.core.quant.delay)

    E

    F

    +

    H

    + + + +
    +

    I