diff --git a/docs/.buildinfo b/docs/.buildinfo
index 028bec665..ae0de1084 100644
--- a/docs/.buildinfo
+++ b/docs/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: d9a08019dec6882195fa0ef4f685a5cd
+config: 029b3c75722303c51813ef033f4d07df
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/_modules/brevitas/core/bit_width/const.html b/docs/_modules/brevitas/core/bit_width/const.html
index 63cecaf10..476d7ce41 100644
--- a/docs/_modules/brevitas/core/bit_width/const.html
+++ b/docs/_modules/brevitas/core/bit_width/const.html
@@ -8,7 +8,7 @@
- brevitas.core.bit_width.const — Brevitas 0.10.2 documentation
+ brevitas.core.bit_width.const — Brevitas 0.11.0 documentation
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""
- def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
+ def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
super(BinaryQuant, self).__init__()
+ assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
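A minimal usage sketch of the new signed flag (constructor as in the hunk above; import paths and the 4-tuple return follow the existing BinaryQuant docs):

import torch
from brevitas.core.quant.binary import BinaryQuant
from brevitas.core.scaling.standalone import ConstScaling

binary_quant = BinaryQuant(scaling_impl=ConstScaling(0.1))  # signed defaults to True
out, scale, zero_point, bit_width = binary_quant(torch.randn(4))
# BinaryQuant(scaling_impl=ConstScaling(0.1), signed=False) now fails fast with
# AssertionError: "Unsigned binary quant not supported"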
diff --git a/docs/_modules/brevitas/core/quant/delay.html b/docs/_modules/brevitas/core/quant/delay.html
index 347941eea..a401098b0 100644
--- a/docs/_modules/brevitas/core/quant/delay.html
+++ b/docs/_modules/brevitas/core/quant/delay.html
@@ -8,7 +8,7 @@
- brevitas.core.quant.delay — Brevitas 0.10.2 documentation
+ brevitas.core.quant.delay — Brevitas 0.11.0 documentation
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
+ self.observer_only = brevitas.jit.Attribute(False, bool)
[docs] @brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
bit_width = self.msb_clamp_bit_width_impl()
- threshold = self.scaling_impl(x)
int_threshold = self.int_scaling_impl(bit_width)
- scale = threshold / int_threshold
+ scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
- y = self.int_quant(scale, zero_point, bit_width, x)
+ if self.observer_only:
+ y = x
+ else:
+ y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width
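Two things change in the hunk above: the threshold / int_threshold division moves inside scaling_impl, and a new observer_only flag lets the quantizer compute (scale, zero_point, bit_width) while passing the tensor through untouched. A self-contained sketch of the observer contract (FakeQuantSketch is illustrative, not Brevitas API):

import torch

class FakeQuantSketch(torch.nn.Module):

    def __init__(self, scale, zero_point):
        super().__init__()
        self.scale, self.zero_point = scale, zero_point
        self.observer_only = False  # mirrors the attribute added above

    def forward(self, x):
        q = torch.round(x / self.scale + self.zero_point)
        dq = (q - self.zero_point) * self.scale
        # Observer mode: parameters are still computed, data is untouched.
        return x if self.observer_only else dq

fq = FakeQuantSketch(torch.tensor(0.1), torch.tensor(0.0))
x = torch.randn(4)
fq.observer_only = True
assert torch.equal(fq(x), x)  # calibration: observe only
fq.observer_only = False
print(fq(x))                  # evaluation: fake-quantized values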
self.pre_zero_point_impl = pre_zero_point_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
+ self.observer_only = brevitas.jit.Attribute(False, bool)
[docs] @brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -594,10 +598,12 @@ Source code for brevitas.core.quant.int
pre_threshold = self.pre_scaling_impl(x)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
- threshold = self.scaling_impl(x)
- scale = threshold / int_threshold
+ scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
- y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+ if self.observer_only:
+ y = x
+ else:
+ y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
@@ -660,10 +666,12 @@ Source code for brevitas.core.quant.int
pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
- threshold = self.scaling_impl(x)
- scale = threshold / int_threshold
+ scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
- y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+ if self.observer_only:
+ y = x
+ else:
+ y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
self,
narrow_range: bool,
signed: bool,
+ input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
@@ -470,9 +471,11 @@ Source code for brevitas.core.quant.int_base
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
+ self.input_view_impl = input_view_impl
[docs] @brevitas.jit.script_method
def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
+ x = self.input_view_impl(x)
y = x / scale
y = y + zero_point
min_int_val = self.min_int(bit_width)
@@ -534,6 +537,7 @@ Source code for brevitas.core.quant.int_base
self,
narrow_range: bool,
signed: bool,
+ input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
@@ -543,11 +547,13 @@ Source code for brevitas.core.quant.int_base
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
+ self.input_view_impl = input_view_impl
[docs] @brevitas.jit.script_method
def to_int(
self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor,
x: Tensor) -> Tensor:
+ x = self.input_view_impl(x)
y = x / pre_scale
y = y + pre_zero_point
min_int_val = self.min_int(bit_width)
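The new input_view_impl hook reshapes the input before the integer mapping, which is what makes group-wise quantization possible: a per-group scale can then broadcast against the grouped view inside x / scale. A hypothetical view implementation (group_view is an assumption, not the library's code):

import torch

def group_view(x: torch.Tensor, group_size: int = 4) -> torch.Tensor:
    # Split the last dimension into (num_groups, group_size) so that a scale
    # of shape (..., num_groups, 1) broadcasts during x / scale.
    *lead, last = x.shape
    return x.view(*lead, last // group_size, group_size)

x = torch.randn(2, 8)
assert group_view(x).shape == (2, 2, 4)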
diff --git a/docs/_modules/brevitas/core/quant/ternary.html b/docs/_modules/brevitas/core/quant/ternary.html
index 9ac3f7859..b77615c2c 100644
--- a/docs/_modules/brevitas/core/quant/ternary.html
+++ b/docs/_modules/brevitas/core/quant/ternary.html
@@ -8,7 +8,7 @@
- brevitas.core.quant.ternary — Brevitas 0.10.2 documentation
+ brevitas.core.quant.ternary — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/brevitas/core/restrict_val.html b/docs/_modules/brevitas/core/restrict_val.html
index cdf92c705..724e23eaf 100644
--- a/docs/_modules/brevitas/core/restrict_val.html
+++ b/docs/_modules/brevitas/core/restrict_val.html
@@ -8,7 +8,7 @@
- brevitas.core.restrict_val — Brevitas 0.10.2 documentation
+ brevitas.core.restrict_val — Brevitas 0.11.0 documentation
@@ -446,7 +446,7 @@ Source code for brevitas.core.restrict_val
self.restrict_value_impl = Identity()
@brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.restrict_value_impl(x)
x = self.clamp_min_ste(x)
return x
@@ -462,7 +462,7 @@ Source code for brevitas.core.restrict_val
self.restrict_value_impl = Identity()
@brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.restrict_value_impl(x)
return x
@@ -478,7 +478,7 @@ Source code for brevitas.core.restrict_val
self.min_val = scaling_min_val
@brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.clamp_min_ste(x)
return x
@@ -500,8 +500,11 @@ Source code for brevitas.core.restrict_val
+[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+ return x / threshold
+
[docs] @brevitas.jit.script_method
- def forward(self, x: torch.Tensor) -> Tensor:
+ def forward(self, x: Tensor) -> Tensor:
return x
@@ -514,7 +517,7 @@ Source code for brevitas.core.restrict_val
-[docs] def restrict_init_tensor(self, x: torch.Tensor):
+[docs] def restrict_init_tensor(self, x: Tensor):
[docs] def restrict_init_module(self):
@@ -523,8 +526,11 @@ Source code for brevitas.core.restrict_val
+[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+ return x - threshold
+
[docs] @brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.power_of_two(x)
return x
@@ -538,7 +544,7 @@ Source code for brevitas.core.restrict_val
-[docs] def restrict_init_tensor(self, x: torch.Tensor):
+[docs] def restrict_init_tensor(self, x: Tensor):
[docs] def restrict_init_module(self):
@@ -547,8 +553,11 @@ Source code for brevitas.core.restrict_val
+[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+ return x / threshold
+
[docs] @brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
return x
@@ -563,7 +572,7 @@ Source code for brevitas.core.restrict_val
-[docs] def restrict_init_tensor(self, x: torch.Tensor):
+[docs] def restrict_init_tensor(self, x: Tensor):
[docs] def restrict_init_module(self):
@@ -572,8 +581,11 @@ Source code for brevitas.core.restrict_val
+[docs] def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+ return x - threshold
+
[docs] @brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
x = self.power_of_two(x)
return x
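The new combine_scale_threshold methods express the same operation in two domains: FloatRestrictValue works on raw values, where combining is a division, while the power-of-two variants work on log2-encoded values, where the same division becomes a subtraction. A quick numeric check (values are illustrative):

import torch

stats, threshold = torch.tensor(6.0), torch.tensor(127.0)
linear = stats / threshold                              # FloatRestrictValue
log_domain = torch.log2(stats) - torch.log2(threshold)  # power-of-two variants
assert torch.allclose(2.0 ** log_domain, linear)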
diff --git a/docs/_modules/brevitas/core/scaling/int_scaling.html b/docs/_modules/brevitas/core/scaling/int_scaling.html
index ab0663819..3dfd7c93d 100644
--- a/docs/_modules/brevitas/core/scaling/int_scaling.html
+++ b/docs/_modules/brevitas/core/scaling/int_scaling.html
@@ -8,7 +8,7 @@
- brevitas.core.scaling.int_scaling — Brevitas 0.10.2 documentation
+ brevitas.core.scaling.int_scaling — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/brevitas/core/scaling/runtime.html b/docs/_modules/brevitas/core/scaling/runtime.html
index 566a80cc3..31e87e7bc 100644
--- a/docs/_modules/brevitas/core/scaling/runtime.html
+++ b/docs/_modules/brevitas/core/scaling/runtime.html
@@ -8,7 +8,7 @@
- brevitas.core.scaling.runtime — Brevitas 0.10.2 documentation
+ brevitas.core.scaling.runtime — Brevitas 0.11.0 documentation
@@ -421,6 +421,7 @@ Source code for brevitas.core.scaling.runtime
import brevitas.config as config
from brevitas.core.function_wrapper import Identity
from brevitas.core.restrict_val import _RestrictClampValue
+from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import _RuntimeStats
from brevitas.core.stats import DEFAULT_MOMENTUM
@@ -437,8 +438,8 @@ Source code for brevitas.core.scaling.runtime
scaling_stats_input_view_shape_impl: Module,
scaling_stats_input_concat_dim: int,
tracked_parameter_list: List[torch.nn.Parameter],
- restrict_scaling_impl: Module,
scaling_shape: Tuple[int, ...],
+ restrict_scaling_impl: Module = FloatRestrictValue(),
affine_rescaling: bool = False,
affine_shift_scale: bool = False,
scaling_min_val: Optional[float] = None,
@@ -461,9 +462,12 @@ Source code for brevitas.core.scaling.runtime
device)
[docs] @brevitas.jit.script_method
- def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.parameter_list_stats()
- return self.stats_scaling_impl(stats)
+ if threshold is None:
+ threshold = torch.ones(1).type_as(stats)
+ return self.stats_scaling_impl(stats, threshold)
class _StatsScaling(brevitas.jit.ScriptModule):
@@ -488,10 +492,16 @@ Source code for brevitas.core.scaling.runtime
self.affine_rescaling = Identity()
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+ self.restrict_scaling_impl = restrict_scaling_impl
@brevitas.jit.script_method
- def forward(self, stats: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(stats)
+ threshold = self.restrict_scaling_pre(threshold)
stats = self.restrict_scaling_pre(stats)
+ stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
return stats
@@ -503,10 +513,10 @@ Source code for brevitas.core.scaling.runtime
self,
scaling_stats_impl: Module,
scaling_stats_input_view_shape_impl: Module,
- restrict_scaling_impl: Module,
scaling_shape: Tuple[int, ...],
affine_rescaling: bool = False,
affine_shift_scale: bool = False,
+ restrict_scaling_impl: Module = FloatRestrictValue(),
scaling_stats_momentum: float = DEFAULT_MOMENTUM,
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
@@ -530,9 +540,9 @@ Source code for brevitas.core.scaling.runtime
device)
[docs] @brevitas.jit.script_method
- def forward(self, x: torch.Tensor):
+ def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.runtime_stats(x)
- return self.stats_scaling_impl(stats)
+ return self.stats_scaling_impl(stats, threshold)
class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -568,6 +578,38 @@ Source code for brevitas.core.scaling.runtime
missing_keys.remove(affine_weight_key)
if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys:
missing_keys.remove(affine_bias_key)
+
+
+[docs]class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
+
+ def __init__(
+ self,
+ group_size: int,
+ group_dim: int,
+ input_view_impl: Module,
+ scaling_stats_impl: Module,
+ scaling_min_val: Optional[float],
+ restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
+ super(RuntimeDynamicGroupStatsScaling, self).__init__()
+ self.group_size = group_size
+ self.group_dim = group_dim
+ self.scaling_stats_impl = scaling_stats_impl
+ self.scaling_min_val = scaling_min_val
+ self.input_view_impl = input_view_impl
+ self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+
+[docs] @brevitas.jit.script_method
+ def forward(
+ self,
+ stats_input: torch.Tensor,
+ threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(stats_input)
+ stats_input_reshaped = self.input_view_impl(stats_input)
+ out = self.scaling_stats_impl(stats_input_reshaped) / threshold
+ # Restrict and clamp the scale to scaling_min_val
+ out = self.restrict_clamp_scaling(out)
+ return out
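End to end, the module above produces one scale per group at runtime. A hedged sketch with a hypothetical absmax standing in for scaling_stats_impl and the restrict/clamp step omitted:

import torch

def dynamic_group_scale(x: torch.Tensor, group_size: int = 4) -> torch.Tensor:
    *lead, last = x.shape
    grouped = x.view(*lead, last // group_size, group_size)  # input_view_impl
    stats = grouped.abs().amax(dim=-1, keepdim=True)         # scaling_stats_impl
    threshold = torch.ones(1).type_as(x)                     # default threshold
    return stats / threshold

x = torch.randn(2, 8)
assert dynamic_group_scale(x).shape == (2, 2, 1)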
diff --git a/docs/_modules/brevitas/core/scaling/standalone.html b/docs/_modules/brevitas/core/scaling/standalone.html
index 54ec5390e..2948d3552 100644
--- a/docs/_modules/brevitas/core/scaling/standalone.html
+++ b/docs/_modules/brevitas/core/scaling/standalone.html
@@ -8,7 +8,7 @@
- brevitas.core.scaling.standalone — Brevitas 0.10.2 documentation
+ brevitas.core.scaling.standalone — Brevitas 0.11.0 documentation
@@ -425,6 +425,7 @@ Source code for brevitas.core.scaling.standalone
from brevitas.core.restrict_val import _ClampValue
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.restrict_val import _RestrictValue
+from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.scaling.runtime import _StatsScaling
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import _Stats
@@ -470,7 +471,7 @@ Source code for brevitas.core.scaling.standalone
def __init__(
self,
scaling_init: Union[float, Tensor],
- restrict_scaling_impl: Optional[Module] = None,
+ restrict_scaling_impl: Module = FloatRestrictValue(),
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
@@ -478,18 +479,23 @@ Source code for brevitas.core.scaling.standalone
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
if isinstance(scaling_init, Tensor):
scaling_init = scaling_init.to(device=device, dtype=dtype)
- if restrict_scaling_impl is not None:
- scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+ scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+ self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
self.value = StatelessBuffer(scaling_init.detach())
else:
- if restrict_scaling_impl is not None:
- scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+ scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+ self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
[docs] @brevitas.jit.script_method
- def forward(self, placeholder: Tensor) -> Tensor:
- value = self.value()
- restricted_value = self.restrict_clamp_scaling(value)
+ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(placeholder)
+ # We first apply any scaling restriction to the threshold
+ # For IntQuant this is a no-op, keeping the result backwards-compatible.
+ threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+ restricted_value = self.restrict_clamp_scaling(self.value())
+ restricted_value = restricted_value / threshold
return restricted_value
@@ -536,7 +542,7 @@ Source code for brevitas.core.scaling.standalone
self,
scaling_init: Union[float, Tensor],
scaling_shape: Optional[Tuple[int, ...]] = None,
- restrict_scaling_impl: Optional[Module] = None,
+ restrict_scaling_impl: Module = FloatRestrictValue(),
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
@@ -551,17 +557,24 @@ Source code for brevitas.core.scaling.standalone
scaling_init = scaling_init.detach()
else:
scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
- if restrict_scaling_impl is not None:
- scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+
+ scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+ self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+
if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
self.value = Parameter(scaling_init)
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
[docs] @brevitas.jit.script_method
- def forward(self, placeholder: Tensor) -> Tensor:
+ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(placeholder)
+ # We first apply any scaling restriction to the threshold
+ # For IntQuant this is a no-op, keeping the result backwards-compatible.
+ threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
- return value
+ return value / threshold
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -588,8 +601,8 @@ Source code for brevitas.core.scaling.standalone
scaling_stats_input_view_shape_impl: Module,
scaling_stats_input_concat_dim: int,
tracked_parameter_list: List[torch.nn.Parameter],
- restrict_scaling_impl: Module,
scaling_shape: Tuple[int, ...],
+ restrict_scaling_impl: Module = FloatRestrictValue(),
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
@@ -600,30 +613,38 @@ Source code for brevitas.core.scaling.standalone
scaling_stats_input_view_shape_impl,
scaling_stats_input_concat_dim,
tracked_parameter_list)
+ self.restrict_scaling_impl = restrict_scaling_impl
self.stats_scaling_impl = _StatsScaling(
restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
self.init_done: bool = brevitas.jit.Attribute(False, bool)
self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
- if restrict_scaling_impl is not None:
- self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
- else:
- self.restrict_inplace_preprocess = Identity()
+ self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+ self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
[docs] @brevitas.jit.script_method
- def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(ignored)
+ # Threshold division must happen after we update self.value, but before we apply restrict_preprocess
+ # This is because we don't want to store a parameter dependent on a runtime value (threshold)
+ # And because restrict needs to happen after we divide by threshold
if self.init_done:
- value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+ threshold = self.restrict_inplace_preprocess(threshold)
+ value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+ value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
return value
else:
stats = self.parameter_list_stats()
# workaround to avoid find_unused_parameters=True in DDP
stats = stats + 0. * self.value
if self.local_loss_mode:
- return self.stats_scaling_impl(stats)
+ return self.stats_scaling_impl(stats, threshold)
stats = self.restrict_inplace_preprocess(stats)
+ threshold = self.restrict_inplace_preprocess(threshold)
inplace_tensor_mul(self.value.detach(), stats)
- value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+ value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+ value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
self.init_done = True
return value
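The ordering constraint in the comment above is easiest to see in the power-of-two case: self.value stores log2(stats), the threshold is mapped into the same log domain, and only the combined value is rounded and exponentiated, so the final scale, threshold division included, is a true power of two. An illustrative check:

import torch

stats, threshold = torch.tensor(6.0), torch.tensor(127.0)
value = torch.log2(stats)                 # what self.value stores (log domain)
combined = value - torch.log2(threshold)  # combine_scale_threshold: subtraction
scale = 2.0 ** torch.round(combined)      # restrict_clamp_scaling
assert scale == 2.0 ** torch.round(torch.log2(stats / threshold))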
@@ -631,7 +652,7 @@ Source code for brevitas.core.scaling.standalone
output_dict = super(ParameterFromStatsFromParameterScaling, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
- if not self.init_done:
+ if not self.init_done and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
return output_dict
@@ -700,7 +721,7 @@ Source code for brevitas.core.scaling.standalone
scaling_stats_impl: Module,
scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
- restrict_scaling_impl: Optional[Module] = None,
+ restrict_scaling_impl: Module = FloatRestrictValue(),
scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
@@ -715,19 +736,19 @@ Source code for brevitas.core.scaling.standalone
scaling_stats_momentum, Optional[float])
self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
+ self.restrict_scaling_impl = restrict_scaling_impl
self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
self.clamp_scaling = _ClampValue(scaling_min_val)
self.local_loss_mode: bool = brevitas.jit.Attribute(
False, bool) # required to support MSE eval or variants
- if restrict_scaling_impl is not None:
- self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
- self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
- else:
- self.restrict_inplace_preprocess = Identity()
- self.restrict_preprocess = Identity()
+ self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+ self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
[docs] @brevitas.jit.script_method
- def training_forward(self, stats_input: Tensor) -> Tensor:
+ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
+ # Threshold division must happen after we update self.value, but before we apply restrict_preprocess
+ # This is because we don't want to store a parameter dependent on a runtime value (threshold)
+ # And because restrict needs to happen after we divide by threshold
if self.counter < self.collect_stats_steps:
stats_input = self.stats_input_view_shape_impl(stats_input)
stats = self.stats(stats_input)
@@ -737,32 +758,41 @@ Source code for brevitas.core.scaling.standalone
new_counter = self.counter + 1
# Whenever we are in local loss mode, we update neither the counter nor the buffer
if self.local_loss_mode:
- return abs_binary_sign_grad(clamped_stats)
+ # Local loss mode, we early exit and divide by threshold
+ return abs_binary_sign_grad(clamped_stats / threshold)
if self.counter == 0:
inplace_tensor_mul(self.buffer, clamped_stats.detach())
else:
inplace_momentum_update(
self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter)
self.counter = new_counter
- return abs_binary_sign_grad(clamped_stats)
+ return abs_binary_sign_grad(clamped_stats / threshold)
elif self.counter == self.collect_stats_steps:
self.restrict_inplace_preprocess(self.buffer)
inplace_tensor_mul(self.value.detach(), self.buffer)
+ threshold = self.restrict_preprocess(threshold)
+ value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
self.counter = self.counter + 1
- return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+ return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
else:
- return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+ threshold = self.restrict_preprocess(threshold)
+ value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+ return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
[docs] @brevitas.jit.script_method
- def forward(self, stats_input: Tensor) -> Tensor:
+ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+ if threshold is None:
+ threshold = torch.ones(1).type_as(stats_input)
if self.training:
- return self.training_forward(stats_input)
+ # Threshold division handled inside the training_forward
+ return self.training_forward(stats_input, threshold)
else:
if self.counter <= self.collect_stats_steps:
- out = self.buffer
+ out = self.buffer / threshold
out = self.restrict_preprocess(out)
else:
- out = self.value
+ threshold = self.restrict_preprocess(threshold)
+ out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
return out
@@ -772,7 +802,7 @@ Source code for brevitas.core.scaling.standalone
# Avoid saving the buffer
del output_dict[prefix + 'buffer']
# Avoid saving the init value
- if self.counter == 0:
+ if self.counter == 0 and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
# Save buffer into value for any non-zero number of collection steps
elif self.counter <= self.collect_stats_steps:
diff --git a/docs/_modules/brevitas/core/stats/stats_op.html b/docs/_modules/brevitas/core/stats/stats_op.html
index 5d95da826..18dffb1fc 100644
--- a/docs/_modules/brevitas/core/stats/stats_op.html
+++ b/docs/_modules/brevitas/core/stats/stats_op.html
@@ -8,7 +8,7 @@
- brevitas.core.stats.stats_op — Brevitas 0.10.2 documentation
+ brevitas.core.stats.stats_op — Brevitas 0.11.0 documentation
@@ -420,8 +420,11 @@ Source code for brevitas.core.stats.stats_op
import brevitas
from brevitas import config
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
+from brevitas.quant_tensor import _unpack_quant_tensor
# Use a custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
@@ -849,6 +852,19 @@ Source code for brevitas.core.stats.stats_op
m.local_loss_mode = enabled
+def _set_observer_mode(module, enabled, previous_observer_mode):
+ for m in module.modules():
+ if hasattr(m, 'observer_only'):
+ previous_observer_mode[m] = m.observer_only
+ m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+ for m in module.modules():
+ if hasattr(m, 'observer_only'):
+ m.observer_only = previous_observer_mode[m]
+
+
[docs]class MSE(torch.nn.Module):
# References:
# https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
@@ -866,7 +882,12 @@ Source code for brevitas.core.stats.stats_op
self.mse_init_op = mse_init_op
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
+ self.previous_observer_mode = dict()
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+ self.set_observer_mode = lambda enabled: _set_observer_mode(
+ proxy_module, enabled, self.previous_observer_mode)
+ self.restore_observer_mode = lambda: _restore_observer_mode(
+ proxy_module, self.previous_observer_mode)
self.internal_candidate = None
self.num = mse_iters
self.search_method = mse_search_method
@@ -887,11 +908,12 @@ Source code for brevitas.core.stats.stats_op
self.internal_candidate = candidate
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
+ self.set_observer_mode(False)
quant_value = self.proxy_forward(x)
- if isinstance(quant_value, tuple):
- quant_value = quant_value[0]
+ quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
+ self.restore_observer_mode()
return loss
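The observer toggling around the proxy call follows a save/restore discipline, so repeated or nested searches leave every flag exactly as they found it. A runnable sketch of the pattern (Dummy stands in for quantizer submodules):

import torch

class Dummy(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.observer_only = True

proxy = Dummy()
previous = {m: m.observer_only for m in proxy.modules() if hasattr(m, 'observer_only')}
for m in previous:
    m.observer_only = False  # force actual quantization while scoring a candidate
# ... evaluate the candidate loss here ...
for m, flag in previous.items():
    m.observer_only = flag   # restore the exact previous state
assert proxy.observer_only is True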
[docs] def mse_grid_search(self, xl, x):
@@ -954,6 +976,244 @@ Source code for brevitas.core.stats.stats_op
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate
+
+
+[docs]class HalfQuadraticOptimizerScale(torch.nn.Module):
+ # References:
+ # https://mobiusml.github.io/hqq_blog/
+ # https://github.com/mobiusml/hqq?tab=readme-ov-file
+
+ def __init__(
+ self,
+ proxy_module,
+ hqo_init_op_scale,
+ keepdim: bool,
+ inner_stats_input_view_shape_impl: torch.nn.Module,
+ scaling_min_val: Optional[float] = None,
+ stats_reduce_dim: Optional[int] = None,
+ int_scaling_impl=None,
+ bit_width_impl=None,
+ hqo_beta_scale: float = 1e5,
+ hqo_kappa_scale: float = 1.01,
+ hqo_lp_norm_scale: float = .7,
+ hqo_iters_scale: int = 1000):
+ super(HalfQuadraticOptimizerScale, self).__init__()
+ self.hqo_init_op = hqo_init_op_scale
+ self.input_view_shape_impl = inner_stats_input_view_shape_impl
+ self.proxy_forward = proxy_module.forward
+ self.previous_observer_mode = dict()
+ self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+ self.set_observer_mode = lambda enabled: _set_observer_mode(
+ proxy_module, enabled, self.previous_observer_mode)
+ self.restore_observer_mode = lambda: _restore_observer_mode(
+ proxy_module, self.previous_observer_mode)
+ self.internal_candidate = None
+ self.hqo_iters = hqo_iters_scale
+ self.stats_reduce_dim = stats_reduce_dim
+ self.local_loss_mode: bool = False
+
+ self.beta = hqo_beta_scale
+ self.kappa = hqo_kappa_scale
+ self.lp_norm = hqo_lp_norm_scale
+
+ self.int_scaling_impl = int_scaling_impl
+ self.msb_clamp_bit_width_impl = bit_width_impl
+ if scaling_min_val is not None and scaling_min_val != 0:
+ self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
+ else:
+ self.clamp_min_ste = Identity()
+ self.keepdim = keepdim
+
+[docs] def parameter_search(self, xl, x):
+ best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
+ candidate = xl
+ best_candidate = candidate
+ beta = self.beta
+ with torch.no_grad():
+ for i in range(0, self.hqo_iters):
+ self.internal_candidate = candidate
+ self.set_local_loss_mode(True)
+ self.set_observer_mode(False)
+ quant_tensor = self.proxy_forward(x).detach()
+ self.set_local_loss_mode(False)
+ self.restore_observer_mode()
+ loss = torch.abs(quant_tensor.value - x).mean()
+
+ best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
+ if loss >= best_loss:
+ break
+ best_loss = torch.min(loss, best_loss)
+ W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
+ zero_point = quant_tensor.zero_point
+ num = self.input_view_shape_impl(x - W_e).detach()
+ den = self.input_view_shape_impl(
+ torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach()
+ mask = (num != 0.) & (den != 0.)
+ if self.stats_reduce_dim is None:
+ candidate = masked_median(num / den, mask)
+ else:
+ candidate = masked_median(
+ num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim)
+ candidate = candidate.type_as(self.internal_candidate)
+ candidate = self.clamp_min_ste(candidate)
+ bit_width = self.msb_clamp_bit_width_impl()
+ int_threshold = self.int_scaling_impl(bit_width)
+ candidate = candidate * int_threshold
+ candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
+ candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)]
+ beta *= self.kappa
+ return best_candidate
+
+[docs] def optimize(self, x):
+ x_view = self.input_view_shape_impl(x)
+
+ init = self.hqo_init_op(x_view).detach()
+ best_candidate = self.parameter_search(init, x_view)
+
+ # Save for evaluation by other modules (e.g. zp) invoking local loss mode
+ self.internal_candidate = best_candidate.detach()
+ torch.cuda.empty_cache()
+ return best_candidate
+
+[docs] def forward(self, x):
+ if not self.local_loss_mode:
+ with torch.no_grad():
+ return self.optimize(x)
+ else:
+ # This is invoked for the zero-point whenever scale is being optimized first
+ if self.internal_candidate is None:
+ x = self.input_view_shape_impl(x)
+ self.internal_candidate = self.hqo_init_op(x).detach()
+ return self.internal_candidate
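Per the HQQ references above, each iteration alternates a shrinkage step on the quantization residual with a closed-form scale update. A compressed, dependency-free sketch of one such step (all names hypothetical; the Brevitas machinery is replaced by plain tensor ops):

import torch

def hqo_scale_step(x, scale, zero_point, beta, lp_norm=0.7):
    q = torch.round(x / scale + zero_point)  # current integer codes
    xq = (q - zero_point) * scale            # fake-quantized tensor
    r = x - xq                               # quantization residual
    # Half-quadratic split: absorb part of the residual via lp-norm shrinkage.
    w_e = torch.sign(r) * torch.relu(
        torch.abs(r) - (1.0 / beta) * torch.abs(r).clamp_min(1e-12) ** (lp_norm - 1))
    # Closed-form update: median of elementwise ratios over valid entries.
    num, den = x - w_e, q - zero_point
    mask = (num != 0) & (den != 0)
    ratio = torch.where(mask, num / den, torch.full_like(num, float('nan')))
    return ratio.nanmedian()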
+
+
+[docs]class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
+ # References:
+ # https://mobiusml.github.io/hqq_blog/
+ # https://github.com/mobiusml/hqq?tab=readme-ov-file
+
+ def __init__(
+ self,
+ proxy_module,
+ keepdim: bool,
+ hqo_init_op_zp: torch.nn.Module,
+ inner_stats_input_view_shape_impl: torch.nn.Module,
+ stats_reduce_dim: Optional[int] = None,
+ hqo_beta_zp: float = 1e0,
+ hqo_kappa_zp: float = 1.01,
+ hqo_lp_norm_zp: float = .5,
+ hqo_iters_zp: int = 1000):
+ super(HalfQuadraticOptimizerZeroPoint, self).__init__()
+ self.hqo_init_op_zp = hqo_init_op_zp
+ self.input_view_shape_impl = inner_stats_input_view_shape_impl
+ self.proxy_forward = proxy_module.forward
+ self.previous_observer_mode = dict()
+ self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+ self.set_observer_mode = lambda enabled: _set_observer_mode(
+ proxy_module, enabled, self.previous_observer_mode)
+ self.restore_observer_mode = lambda: _restore_observer_mode(
+ proxy_module, self.previous_observer_mode)
+ self.internal_candidate = None
+ self.stats_reduce_dim = stats_reduce_dim
+ self.local_loss_mode: bool = False
+ self.beta = hqo_beta_zp
+ self.kappa = hqo_kappa_zp
+ self.lp_norm = hqo_lp_norm_zp
+ self.hqo_iters = hqo_iters_zp
+ self.keepdim = keepdim
+
+[docs] def parameter_search(self, xl, x):
+ best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
+ candidate = xl
+ best_candidate = candidate
+ with torch.no_grad():
+ for i in range(0, self.hqo_iters):
+ self.internal_candidate = candidate
+ self.set_local_loss_mode(True)
+ self.set_observer_mode(False)
+ quant_tensor = self.proxy_forward(x).detach()
+ self.set_local_loss_mode(False)
+ self.restore_observer_mode()
+ qt_value = self.input_view_shape_impl(quant_tensor.value)
+ qt_scale = self.input_view_shape_impl(quant_tensor.scale)
+ qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
+ qt_int = qt_value / qt_scale + qt_zp
+ loss = torch.abs(qt_value - x).mean()
+ best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
+ if loss >= best_loss:
+ break
+ best_loss = torch.min(loss, best_loss)
+ W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)
+
+ # Compared to the original formulation, the value we're looking for is:
+ # - scaled by qt_scale
+ # - opposite sign
+ val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)
+
+ if self.stats_reduce_dim is None:
+ candidate = torch.mean(val)
+ else:
+ candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim)
+ self.beta *= self.kappa
+ return best_candidate
+
+[docs] def optimize(self, x):
+ x_view = self.input_view_shape_impl(x)
+
+ init = self.hqo_init_op_zp(x_view).detach()
+
+ best_candidate = self.parameter_search(init, x)
+
+ # Save for evaluation by other modules (e.g. zp) invoking local loss mode
+ self.internal_candidate = best_candidate.detach()
+ torch.cuda.empty_cache()
+ return best_candidate
+
+[docs] def forward(self, x):
+ if not self.local_loss_mode:
+ with torch.no_grad():
+ return self.optimize(x)
+ else:
+ # This is invoked for the zero-point whenever scale is being optimized first
+ if self.internal_candidate is None:
+ x = self.input_view_shape_impl(x)
+ self.internal_candidate = self.hqo_init_op_zp(x).detach()
+ return self.internal_candidate
+
+
+[docs]def masked_median(x, mask, dim=None, keepdim=False):
+ """Compute the median of tensor x along dim, ignoring values where mask is False.
+ x and mask need to be broadcastable.
+
+ Args:
+ x (Tensor): Tensor to compute median of.
+ mask (BoolTensor): Same shape as x with True where x is valid and False
+ where x should be masked. Mask should not be all False in any column of
+ dimension dim to avoid NaNs from zero division.
+ dim (int, optional): Dimension to take the median over. Defaults to None, which reduces over all elements.
+ keepdim (bool, optional): Whether the output retains dim. Defaults to False.
+
+ Returns:
+ Tensor: Same shape as x, except dimension dim reduced.
+ """
+ # uncomment this assert for safety but might impact performance
+ # assert (
+ # mask.sum(dim=dim).ne(0).all()
+ # ), "mask should not be all False in any column, causes zero division"
+ x_nan = x.float().masked_fill(~mask, float("nan"))
+ if dim is None:
+ x_median = x_nan.nanmedian()
+ else:
+ x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)
+ return x_median
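A small usage example for masked_median (illustrative; note that for an even number of valid entries torch.nanmedian returns the lower of the two middle values, not their mean):

import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
mask = torch.tensor([[True, False, True], [True, True, False]])
print(masked_median(x, mask, dim=0))  # tensor([1., 5., 3.])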
+
+
+# Shrinking operator
+[docs]def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
+ if lp_norm == 1:
+ return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
+ else:
+ return torch.sign(x) * torch.nn.functional.relu(
+ torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
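For lp_norm == 1 shrink_lp_op reduces to the classic soft-thresholding operator sign(x) * max(|x| - 1/beta, 0). A quick numeric check:

import torch

x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
out = shrink_lp_op(x, beta=1.0, lp_norm=1.0)
assert torch.allclose(out, torch.tensor([-1., 0., 0., 1.]))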
diff --git a/docs/_modules/brevitas/core/utils.html b/docs/_modules/brevitas/core/utils.html
index 6c90c9602..a7378d4d7 100644
--- a/docs/_modules/brevitas/core/utils.html
+++ b/docs/_modules/brevitas/core/utils.html
@@ -8,7 +8,7 @@
- brevitas.core.utils — Brevitas 0.10.2 documentation
+ brevitas.core.utils — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/brevitas/core/zero_point.html b/docs/_modules/brevitas/core/zero_point.html
index 53919717c..986cdc5a3 100644
--- a/docs/_modules/brevitas/core/zero_point.html
+++ b/docs/_modules/brevitas/core/zero_point.html
@@ -8,7 +8,7 @@
- brevitas.core.zero_point — Brevitas 0.10.2 documentation
+ brevitas.core.zero_point — Brevitas 0.11.0 documentation
@@ -691,8 +691,9 @@ Source code for brevitas.core.zero_point
output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
- if not self.init_done:
- del output_dict[prefix + 'value']
+ if not self.init_done and not config._FULL_STATE_DICT:
+ del output_dict[prefix + 'value']
+ return output_dict
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
diff --git a/docs/_modules/brevitas/function/ops.html b/docs/_modules/brevitas/function/ops.html
index c325b40bd..53806a3b6 100644
--- a/docs/_modules/brevitas/function/ops.html
+++ b/docs/_modules/brevitas/function/ops.html
@@ -8,7 +8,7 @@
- brevitas.function.ops — Brevitas 0.10.2 documentation
+ brevitas.function.ops — Brevitas 0.11.0 documentation
@@ -599,7 +599,7 @@ Source code for brevitas.function.ops
return value
-[docs]@brevitas.jit.script
+[docs]@brevitas.jit.ignore
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
max_mantissa = torch.sum((
diff --git a/docs/_modules/brevitas/function/ops_ste.html b/docs/_modules/brevitas/function/ops_ste.html
index 14b00539e..182924071 100644
--- a/docs/_modules/brevitas/function/ops_ste.html
+++ b/docs/_modules/brevitas/function/ops_ste.html
@@ -8,7 +8,7 @@
- brevitas.function.ops_ste — Brevitas 0.10.2 documentation
+ brevitas.function.ops_ste — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/brevitas/function/shape.html b/docs/_modules/brevitas/function/shape.html
index 9b6e94515..cb64c5178 100644
--- a/docs/_modules/brevitas/function/shape.html
+++ b/docs/_modules/brevitas/function/shape.html
@@ -8,7 +8,7 @@
- brevitas.function.shape — Brevitas 0.10.2 documentation
+ brevitas.function.shape — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/brevitas/ops/autograd_ste_ops.html b/docs/_modules/brevitas/ops/autograd_ste_ops.html
index 21b701fe8..dbb265da7 100644
--- a/docs/_modules/brevitas/ops/autograd_ste_ops.html
+++ b/docs/_modules/brevitas/ops/autograd_ste_ops.html
@@ -8,7 +8,7 @@
- brevitas.ops.autograd_ste_ops — Brevitas 0.10.2 documentation
+ brevitas.ops.autograd_ste_ops — Brevitas 0.11.0 documentation
diff --git a/docs/_modules/index.html b/docs/_modules/index.html
index 5380f2cf8..c1ddf62d6 100644
--- a/docs/_modules/index.html
+++ b/docs/_modules/index.html
@@ -8,7 +8,7 @@
- Overview: module code — Brevitas 0.10.2 documentation
+ Overview: module code — Brevitas 0.11.0 documentation
diff --git a/docs/_static/documentation_options.js b/docs/_static/documentation_options.js
index db6d22fe9..9dc22d647 100644
--- a/docs/_static/documentation_options.js
+++ b/docs/_static/documentation_options.js
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
- VERSION: '0.10.2',
+ VERSION: '0.11.0',
LANGUAGE: 'en',
COLLAPSE_INDEX: false,
BUILDER: 'html',
diff --git a/docs/_static/pygments.css b/docs/_static/pygments.css
index 997797f27..012e6a00a 100644
--- a/docs/_static/pygments.css
+++ b/docs/_static/pygments.css
@@ -3,77 +3,77 @@ html[data-theme="light"] .highlight td.linenos .normal { color: inherit; backgro
html[data-theme="light"] .highlight span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
html[data-theme="light"] .highlight td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
html[data-theme="light"] .highlight span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
-html[data-theme="light"] .highlight .hll { background-color: #7971292e }
-html[data-theme="light"] .highlight { background: #fefefe; color: #545454 }
-html[data-theme="light"] .highlight .c { color: #797129 } /* Comment */
-html[data-theme="light"] .highlight .err { color: #d91e18 } /* Error */
-html[data-theme="light"] .highlight .k { color: #7928a1 } /* Keyword */
-html[data-theme="light"] .highlight .l { color: #797129 } /* Literal */
-html[data-theme="light"] .highlight .n { color: #545454 } /* Name */
-html[data-theme="light"] .highlight .o { color: #008000 } /* Operator */
-html[data-theme="light"] .highlight .p { color: #545454 } /* Punctuation */
-html[data-theme="light"] .highlight .ch { color: #797129 } /* Comment.Hashbang */
-html[data-theme="light"] .highlight .cm { color: #797129 } /* Comment.Multiline */
-html[data-theme="light"] .highlight .cp { color: #797129 } /* Comment.Preproc */
-html[data-theme="light"] .highlight .cpf { color: #797129 } /* Comment.PreprocFile */
-html[data-theme="light"] .highlight .c1 { color: #797129 } /* Comment.Single */
-html[data-theme="light"] .highlight .cs { color: #797129 } /* Comment.Special */
-html[data-theme="light"] .highlight .gd { color: #007faa } /* Generic.Deleted */
+html[data-theme="light"] .highlight .hll { background-color: #fae4c2 }
+html[data-theme="light"] .highlight { background: #fefefe; color: #080808 }
+html[data-theme="light"] .highlight .c { color: #515151 } /* Comment */
+html[data-theme="light"] .highlight .err { color: #a12236 } /* Error */
+html[data-theme="light"] .highlight .k { color: #6730c5 } /* Keyword */
+html[data-theme="light"] .highlight .l { color: #7f4707 } /* Literal */
+html[data-theme="light"] .highlight .n { color: #080808 } /* Name */
+html[data-theme="light"] .highlight .o { color: #00622f } /* Operator */
+html[data-theme="light"] .highlight .p { color: #080808 } /* Punctuation */
+html[data-theme="light"] .highlight .ch { color: #515151 } /* Comment.Hashbang */
+html[data-theme="light"] .highlight .cm { color: #515151 } /* Comment.Multiline */
+html[data-theme="light"] .highlight .cp { color: #515151 } /* Comment.Preproc */
+html[data-theme="light"] .highlight .cpf { color: #515151 } /* Comment.PreprocFile */
+html[data-theme="light"] .highlight .c1 { color: #515151 } /* Comment.Single */
+html[data-theme="light"] .highlight .cs { color: #515151 } /* Comment.Special */
+html[data-theme="light"] .highlight .gd { color: #005b82 } /* Generic.Deleted */
html[data-theme="light"] .highlight .ge { font-style: italic } /* Generic.Emph */
-html[data-theme="light"] .highlight .gh { color: #007faa } /* Generic.Heading */
+html[data-theme="light"] .highlight .gh { color: #005b82 } /* Generic.Heading */
html[data-theme="light"] .highlight .gs { font-weight: bold } /* Generic.Strong */
-html[data-theme="light"] .highlight .gu { color: #007faa } /* Generic.Subheading */
-html[data-theme="light"] .highlight .kc { color: #7928a1 } /* Keyword.Constant */
-html[data-theme="light"] .highlight .kd { color: #7928a1 } /* Keyword.Declaration */
-html[data-theme="light"] .highlight .kn { color: #7928a1 } /* Keyword.Namespace */
-html[data-theme="light"] .highlight .kp { color: #7928a1 } /* Keyword.Pseudo */
-html[data-theme="light"] .highlight .kr { color: #7928a1 } /* Keyword.Reserved */
-html[data-theme="light"] .highlight .kt { color: #797129 } /* Keyword.Type */
-html[data-theme="light"] .highlight .ld { color: #797129 } /* Literal.Date */
-html[data-theme="light"] .highlight .m { color: #797129 } /* Literal.Number */
-html[data-theme="light"] .highlight .s { color: #008000 } /* Literal.String */
-html[data-theme="light"] .highlight .na { color: #797129 } /* Name.Attribute */
-html[data-theme="light"] .highlight .nb { color: #797129 } /* Name.Builtin */
-html[data-theme="light"] .highlight .nc { color: #007faa } /* Name.Class */
-html[data-theme="light"] .highlight .no { color: #007faa } /* Name.Constant */
-html[data-theme="light"] .highlight .nd { color: #797129 } /* Name.Decorator */
-html[data-theme="light"] .highlight .ni { color: #008000 } /* Name.Entity */
-html[data-theme="light"] .highlight .ne { color: #7928a1 } /* Name.Exception */
-html[data-theme="light"] .highlight .nf { color: #007faa } /* Name.Function */
-html[data-theme="light"] .highlight .nl { color: #797129 } /* Name.Label */
-html[data-theme="light"] .highlight .nn { color: #545454 } /* Name.Namespace */
-html[data-theme="light"] .highlight .nx { color: #545454 } /* Name.Other */
-html[data-theme="light"] .highlight .py { color: #007faa } /* Name.Property */
-html[data-theme="light"] .highlight .nt { color: #007faa } /* Name.Tag */
-html[data-theme="light"] .highlight .nv { color: #d91e18 } /* Name.Variable */
-html[data-theme="light"] .highlight .ow { color: #7928a1 } /* Operator.Word */
-html[data-theme="light"] .highlight .pm { color: #545454 } /* Punctuation.Marker */
-html[data-theme="light"] .highlight .w { color: #545454 } /* Text.Whitespace */
-html[data-theme="light"] .highlight .mb { color: #797129 } /* Literal.Number.Bin */
-html[data-theme="light"] .highlight .mf { color: #797129 } /* Literal.Number.Float */
-html[data-theme="light"] .highlight .mh { color: #797129 } /* Literal.Number.Hex */
-html[data-theme="light"] .highlight .mi { color: #797129 } /* Literal.Number.Integer */
-html[data-theme="light"] .highlight .mo { color: #797129 } /* Literal.Number.Oct */
-html[data-theme="light"] .highlight .sa { color: #008000 } /* Literal.String.Affix */
-html[data-theme="light"] .highlight .sb { color: #008000 } /* Literal.String.Backtick */
-html[data-theme="light"] .highlight .sc { color: #008000 } /* Literal.String.Char */
-html[data-theme="light"] .highlight .dl { color: #008000 } /* Literal.String.Delimiter */
-html[data-theme="light"] .highlight .sd { color: #008000 } /* Literal.String.Doc */
-html[data-theme="light"] .highlight .s2 { color: #008000 } /* Literal.String.Double */
-html[data-theme="light"] .highlight .se { color: #008000 } /* Literal.String.Escape */
-html[data-theme="light"] .highlight .sh { color: #008000 } /* Literal.String.Heredoc */
-html[data-theme="light"] .highlight .si { color: #008000 } /* Literal.String.Interpol */
-html[data-theme="light"] .highlight .sx { color: #008000 } /* Literal.String.Other */
-html[data-theme="light"] .highlight .sr { color: #d91e18 } /* Literal.String.Regex */
-html[data-theme="light"] .highlight .s1 { color: #008000 } /* Literal.String.Single */
-html[data-theme="light"] .highlight .ss { color: #007faa } /* Literal.String.Symbol */
-html[data-theme="light"] .highlight .bp { color: #797129 } /* Name.Builtin.Pseudo */
-html[data-theme="light"] .highlight .fm { color: #007faa } /* Name.Function.Magic */
-html[data-theme="light"] .highlight .vc { color: #d91e18 } /* Name.Variable.Class */
-html[data-theme="light"] .highlight .vg { color: #d91e18 } /* Name.Variable.Global */
-html[data-theme="light"] .highlight .vi { color: #d91e18 } /* Name.Variable.Instance */
-html[data-theme="light"] .highlight .vm { color: #797129 } /* Name.Variable.Magic */
-html[data-theme="light"] .highlight .il { color: #797129 } /* Literal.Number.Integer.Long */
+html[data-theme="light"] .highlight .gu { color: #005b82 } /* Generic.Subheading */
+html[data-theme="light"] .highlight .kc { color: #6730c5 } /* Keyword.Constant */
+html[data-theme="light"] .highlight .kd { color: #6730c5 } /* Keyword.Declaration */
+html[data-theme="light"] .highlight .kn { color: #6730c5 } /* Keyword.Namespace */
+html[data-theme="light"] .highlight .kp { color: #6730c5 } /* Keyword.Pseudo */
+html[data-theme="light"] .highlight .kr { color: #6730c5 } /* Keyword.Reserved */
+html[data-theme="light"] .highlight .kt { color: #7f4707 } /* Keyword.Type */
+html[data-theme="light"] .highlight .ld { color: #7f4707 } /* Literal.Date */
+html[data-theme="light"] .highlight .m { color: #7f4707 } /* Literal.Number */
+html[data-theme="light"] .highlight .s { color: #00622f } /* Literal.String */
+html[data-theme="light"] .highlight .na { color: #912583 } /* Name.Attribute */
+html[data-theme="light"] .highlight .nb { color: #7f4707 } /* Name.Builtin */
+html[data-theme="light"] .highlight .nc { color: #005b82 } /* Name.Class */
+html[data-theme="light"] .highlight .no { color: #005b82 } /* Name.Constant */
+html[data-theme="light"] .highlight .nd { color: #7f4707 } /* Name.Decorator */
+html[data-theme="light"] .highlight .ni { color: #00622f } /* Name.Entity */
+html[data-theme="light"] .highlight .ne { color: #6730c5 } /* Name.Exception */
+html[data-theme="light"] .highlight .nf { color: #005b82 } /* Name.Function */
+html[data-theme="light"] .highlight .nl { color: #7f4707 } /* Name.Label */
+html[data-theme="light"] .highlight .nn { color: #080808 } /* Name.Namespace */
+html[data-theme="light"] .highlight .nx { color: #080808 } /* Name.Other */
+html[data-theme="light"] .highlight .py { color: #005b82 } /* Name.Property */
+html[data-theme="light"] .highlight .nt { color: #005b82 } /* Name.Tag */
+html[data-theme="light"] .highlight .nv { color: #a12236 } /* Name.Variable */
+html[data-theme="light"] .highlight .ow { color: #6730c5 } /* Operator.Word */
+html[data-theme="light"] .highlight .pm { color: #080808 } /* Punctuation.Marker */
+html[data-theme="light"] .highlight .w { color: #080808 } /* Text.Whitespace */
+html[data-theme="light"] .highlight .mb { color: #7f4707 } /* Literal.Number.Bin */
+html[data-theme="light"] .highlight .mf { color: #7f4707 } /* Literal.Number.Float */
+html[data-theme="light"] .highlight .mh { color: #7f4707 } /* Literal.Number.Hex */
+html[data-theme="light"] .highlight .mi { color: #7f4707 } /* Literal.Number.Integer */
+html[data-theme="light"] .highlight .mo { color: #7f4707 } /* Literal.Number.Oct */
+html[data-theme="light"] .highlight .sa { color: #00622f } /* Literal.String.Affix */
+html[data-theme="light"] .highlight .sb { color: #00622f } /* Literal.String.Backtick */
+html[data-theme="light"] .highlight .sc { color: #00622f } /* Literal.String.Char */
+html[data-theme="light"] .highlight .dl { color: #00622f } /* Literal.String.Delimiter */
+html[data-theme="light"] .highlight .sd { color: #00622f } /* Literal.String.Doc */
+html[data-theme="light"] .highlight .s2 { color: #00622f } /* Literal.String.Double */
+html[data-theme="light"] .highlight .se { color: #00622f } /* Literal.String.Escape */
+html[data-theme="light"] .highlight .sh { color: #00622f } /* Literal.String.Heredoc */
+html[data-theme="light"] .highlight .si { color: #00622f } /* Literal.String.Interpol */
+html[data-theme="light"] .highlight .sx { color: #00622f } /* Literal.String.Other */
+html[data-theme="light"] .highlight .sr { color: #a12236 } /* Literal.String.Regex */
+html[data-theme="light"] .highlight .s1 { color: #00622f } /* Literal.String.Single */
+html[data-theme="light"] .highlight .ss { color: #005b82 } /* Literal.String.Symbol */
+html[data-theme="light"] .highlight .bp { color: #7f4707 } /* Name.Builtin.Pseudo */
+html[data-theme="light"] .highlight .fm { color: #005b82 } /* Name.Function.Magic */
+html[data-theme="light"] .highlight .vc { color: #a12236 } /* Name.Variable.Class */
+html[data-theme="light"] .highlight .vg { color: #a12236 } /* Name.Variable.Global */
+html[data-theme="light"] .highlight .vi { color: #a12236 } /* Name.Variable.Instance */
+html[data-theme="light"] .highlight .vm { color: #7f4707 } /* Name.Variable.Magic */
+html[data-theme="light"] .highlight .il { color: #7f4707 } /* Literal.Number.Integer.Long */
html[data-theme="dark"] .highlight pre { line-height: 125%; }
html[data-theme="dark"] .highlight td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
html[data-theme="dark"] .highlight span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
diff --git a/docs/about.html b/docs/about.html
index 8c96f9c07..f62066e2e 100644
--- a/docs/about.html
+++ b/docs/about.html
@@ -9,7 +9,7 @@
- About — Brevitas 0.10.2 documentation
+ About — Brevitas 0.11.0 documentation
diff --git a/docs/api_reference/brevitas.core.bit_width.html b/docs/api_reference/brevitas.core.bit_width.html
index c2a0a5090..226dccfa2 100644
--- a/docs/api_reference/brevitas.core.bit_width.html
+++ b/docs/api_reference/brevitas.core.bit_width.html
@@ -9,7 +9,7 @@
- brevitas.core.bit_width package — Brevitas 0.10.2 documentation
+ brevitas.core.bit_width package — Brevitas 0.11.0 documentation
@@ -447,11 +447,11 @@ Submodules
class brevitas.core.bit_width.const.BitWidthConst(bit_width, dtype=None, device=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule that returns a constant bit-width wrapped in a float torch.tensor.
Examples
@@ -472,9 +472,9 @@ Submodules
forward()[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -489,12 +489,12 @@
Submodules
class brevitas.core.bit_width.const.BitWidthStatefulConst(bit_width, dtype=None, device=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule that returns a constant bit-width wrapped in a float torch.tensor but retains the
bit-width as part of the module state.
Examples
@@ -517,9 +517,9 @@ Submodules
forward()[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -534,13 +534,13 @@
Submodules
class brevitas.core.bit_width.const.MsbClampBitWidth(bit_width_to_remove_impl, min_overall_bit_width, max_overall_bit_width)[source]#
-Bases: Module
+Bases: Module
-
forward(input_bit_width)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -558,15 +558,15 @@
Submodules
-
class brevitas.core.bit_width.parameter.BitWidthParameter(bit_width, min_bit_width=2, restrict_bit_width_impl=IntRestrictValue( (float_to_int_impl): RoundSte() ), override_pretrained_bit_width=False, dtype=None, device=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule that returns a learnable bit-width wrapped in a float torch.Tensor.
- Parameters:
-bit_width (int) – value to initialize the output learned bit-width.
-min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.
-restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).
-override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.
+bit_width (int) – value to initialize the output learned bit-width.
+min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.
+restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).
+override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.
- Returns:
@@ -576,7 +576,7 @@ Submodules
- Raises:
-RuntimeError – if bit_width < min_bit_width.
+RuntimeError – if bit_width < min_bit_width.
Examples
@@ -597,9 +597,9 @@ Submodules
-
forward()[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -614,13 +614,13 @@
Submodules
-
class brevitas.core.bit_width.parameter.RemoveBitwidthParameter(bit_width_to_remove, override_pretrained_bit_width=False, non_zero_epsilon=1e-06, remove_zero_bit_width=0.1, dtype=None, device=None)[source]#
-Bases: Module
+Bases: Module
-
forward()[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
diff --git a/docs/api_reference/brevitas.core.function_wrapper.html b/docs/api_reference/brevitas.core.function_wrapper.html
index 458e945a3..26217f590 100644
--- a/docs/api_reference/brevitas.core.function_wrapper.html
+++ b/docs/api_reference/brevitas.core.function_wrapper.html
@@ -9,7 +9,7 @@
- brevitas.core.function_wrapper package — Brevitas 0.10.2 documentation
+ brevitas.core.function_wrapper package — Brevitas 0.11.0 documentation
@@ -448,7 +448,7 @@ Submodules
-
class brevitas.core.function_wrapper.clamp.ClampMin(min_val)[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for clamp_min().
Examples
>>> clamp_min = ClampMin(min_val=-2.0)
@@ -459,7 +459,7 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -472,11 +472,45 @@ Submodules
+-
+class brevitas.core.function_wrapper.clamp.FloatClamp(tensor_clamp_impl, signed, inf_values=None, nan_values=None, max_available_float=None, saturating=True, device=None, dtype=None)[source]#
+Bases: Module
+ScriptModule for clamping minifloat formats to their inf/NaN implementations.
+Currently, inf/NaN codes have to be encoded through the mantissa.
+That is, setting inf to 1101.111 (E4M3) is not a valid code.
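+A minimal usage sketch, assuming tensor-valued E4M3 parameters in forward and
+illustrative inputs; the exact structure of the returned value is not asserted
+here.
+Examples
+>>> import torch
+>>> from brevitas.core.function_wrapper import TensorClamp
+>>> float_clamp = FloatClamp(tensor_clamp_impl=TensorClamp(), signed=True)
+>>> x = torch.tensor([600.0, -0.1])
+>>> # E4M3: 4 exponent bits, 3 mantissa bits, bias 7; with saturating=True,
+>>> # out-of-range values saturate to the max representable minifloat
+>>> out = float_clamp(x, torch.tensor(4.), torch.tensor(3.), torch.tensor(7.))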
+
+-
+forward(x, exponent_bit_width, mantissa_bit_width, exponent_bias)[source]#
+Define the computation performed at every call.
+Should be overridden by all subclasses.
+
+Note
+Although the recipe for forward pass needs to be defined within
+this function, one should call the Module instance afterwards
+instead of this since the former takes care of running the
+registered hooks while the latter silently ignores them.
+
+
+
+
+
+
+
+
+
-
class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)[source]#
-Bases: Module
-ScriptModule wrapper for clamp().
+Bases: Module
+ScriptModule wrapper for clamp().
Examples
>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
>>> scalar_clamp(torch.tensor([-3.0, 3.0]))
@@ -486,7 +520,7 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -502,7 +536,7 @@ Submodules
-
class brevitas.core.function_wrapper.clamp.TensorClamp[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for tensor_clamp().
Examples
>>> tensor_clamp = TensorClamp()
@@ -515,7 +549,7 @@ Submodules
-
forward(x, min_val, max_val)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -535,7 +569,7 @@ Submodules
-
class brevitas.core.function_wrapper.misc.Identity[source]#
-Bases: Module
+Bases: Module
Identity ScriptModule.
Examples
>>> identity = Identity()
@@ -548,9 +582,9 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -565,7 +599,7 @@
Submodules
-
class brevitas.core.function_wrapper.misc.InplaceLogTwo[source]#
-Bases: Module
+Bases: Module
Module wrapper for log2_().
Examples
>>> inplace_log_two = InplaceLogTwo()
@@ -580,9 +614,9 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -597,8 +631,8 @@
Submodules
-
class brevitas.core.function_wrapper.misc.LogTwo[source]#
-Bases: Module
-ScriptModule wrapper for log2().
+Bases: Module
+ScriptModule wrapper for log2().
Examples
>>> log_two = LogTwo()
>>> x = torch.tensor(8.0)
@@ -609,9 +643,9 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -626,7 +660,7 @@
Submodules
-
class brevitas.core.function_wrapper.misc.PowerOfTwo[source]#
-Bases: Module
+Bases: Module
ScriptModule implementation of 2.0 ** x.
Examples
>>> power_of_two = PowerOfTwo()
@@ -638,9 +672,9 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
-:rtype: Tensor
+:rtype: Tensor
Note
Although the recipe for forward pass needs to be defined within
@@ -659,12 +693,12 @@
Submodules
-
class brevitas.core.function_wrapper.ops_ste.CeilSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for ceil_ste().
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -680,12 +714,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.DPURoundSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for dpu_round_ste().
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -701,12 +735,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.FloorSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for floor_ste().
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -722,12 +756,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for tensor_clamp_ste_().
-
forward(x, min_val, max_val)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -743,12 +777,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.RoundSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for round_ste().
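A minimal sketch of the straight-through behaviour shared by the ops_ste
wrappers: the op (here rounding) runs in the forward pass, while the backward
pass treats it as the identity, so gradients flow through unchanged.
Examples
>>> import torch
>>> round_ste = RoundSte()
>>> x = torch.tensor([0.6, 1.2], requires_grad=True)
>>> round_ste(x).sum().backward()
>>> x.grad  # identity gradient, although round() has zero derivative a.e.
tensor([1., 1.])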
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -764,12 +798,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for round_to_zero_ste().
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -785,12 +819,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for scalar_clamp_min_ste().
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -806,12 +840,12 @@ Submodules
-
class brevitas.core.function_wrapper.ops_ste.TensorClampSte[source]#
-Bases: Module
+Bases: Module
ScriptModule wrapper for tensor_clamp_ste().
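A minimal sketch with illustrative values: clamping is applied in the forward
pass, but the straight-through backward lets gradients reach clamped elements
too.
Examples
>>> import torch
>>> tensor_clamp_ste = TensorClampSte()
>>> x = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
>>> y = tensor_clamp_ste(x, torch.tensor(-2.0), torch.tensor(2.0))  # [-2.0, 0.5, 2.0]
>>> y.sum().backward()
>>> x.grad
tensor([1., 1., 1.])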
-
forward(x, min_val, max_val)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -828,10 +862,30 @@ Submodules
brevitas.core.function_wrapper.shape module#
ScriptModule classes to compute the view of a tensor according to various criteria.
+
+-
+class brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView(group_size, group_dim)[source]#
+Bases: Module
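+A minimal sketch, assuming the view splits dimension group_dim into contiguous
+blocks of group_size elements, as used for groupwise quantization; shapes are
+illustrative.
+Examples
+>>> import torch
+>>> block_view = DynamicOverSubChannelBlockView(group_size=4, group_dim=1)
+>>> x = torch.randn(2, 8)
+>>> block_view(x).shape  # expected under this assumption
+torch.Size([2, 2, 4])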
+
+-
+forward(x)[source]#
+Define the computation performed at every call.
+Should be overridden by all subclasses.
+
+Note
+Although the recipe for forward pass needs to be defined within
+this function, one should call the Module instance afterwards
+instead of this since the former takes care of running the
+registered hooks while the latter silently ignores them.
+
+
+
+
+
-
class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView(permute_dims=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.
Examples
@@ -844,7 +898,7 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -860,7 +914,7 @@ Submodules
-
class brevitas.core.function_wrapper.shape.OverBatchOverTensorView(permute_dims=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule to compute the over_batch_over_tensor() view of an input tensor.
Examples
@@ -873,7 +927,7 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -889,7 +943,7 @@ Submodules
-
class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule to compute the over_output_channels() view of an input tensor.
Examples
@@ -902,7 +956,7 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
Should be overridden by all subclasses.
Note
@@ -918,7 +972,7 @@ Submodules
-
class brevitas.core.function_wrapper.shape.OverOutputFeaturesView(permute_dims=None)[source]#
-Bases: Module
+Bases: Module
ScriptModule to compute the over_output_features() view of an input tensor.
Examples
@@ -931,7 +985,27 @@ Submodules
-
forward(x)[source]#
-Defines the computation performed at every call.
+Define the computation performed at every call.
+Should be overridden by all subclasses.
+
+Note
+Although the recipe for forward pass needs to be defined within
+this function, one should call the Module instance afterwards
+instead of this since the former takes care of running the
+registered hooks while the latter silently ignores them.
+
+
+
+
+
+