diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index aa9a4fa2e..59b3fe8ec 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -36,7 +36,7 @@ def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Option
             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         x = self.clamp_min_ste(x)
         return x
@@ -52,7 +52,7 @@ def __init__(self, restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         return x
 
@@ -68,7 +68,7 @@ def __init__(self, scaling_min_val: Optional[float]):
         self.min_val = scaling_min_val
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.clamp_min_ste(x)
         return x
 
@@ -90,11 +90,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x / threshold
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         return x
 
 
@@ -107,7 +107,7 @@ def __init__(self):
     def restrict_init_float(self, x: float):
         return math.log2(x)
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)
 
     def restrict_init_module(self):
@@ -116,11 +116,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x - threshold
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.power_of_two(x)
         return x
 
@@ -134,7 +134,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return x
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return x
 
     def restrict_init_module(self):
@@ -143,11 +143,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x / threshold
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         return x
 
@@ -162,7 +162,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return math.log2(x)
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)
 
     def restrict_init_module(self):
@@ -171,11 +171,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x - threshold
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
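The rename from combine_stats_threshold to combine_scale_threshold reflects what the method actually combines: a scale, already expressed in the implementation's restricted domain, with a runtime threshold. A minimal sketch of the two behaviors above, with illustrative values only: the float/int implementations divide, while the log2-based implementations subtract, since division turns into subtraction in log space.

    import torch

    scale, threshold = torch.tensor(2.0), torch.tensor(8.0)

    # FloatRestrictValue / IntRestrictValue: combine in the linear domain.
    linear = scale / threshold
    # LogFloatRestrictValue / PowerOfTwoRestrictValue: operands are already
    # log2-encoded, so the same combination is a subtraction.
    log2 = torch.log2(scale) - torch.log2(threshold)

    assert torch.allclose(torch.exp2(log2), linear)  # both yield 0.25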
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 3b50142d9..f11eb1f2a 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -11,6 +11,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.restrict_val import _RestrictClampValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _RuntimeStats
 from brevitas.core.stats import DEFAULT_MOMENTUM
@@ -27,8 +28,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
@@ -90,7 +91,7 @@ def forward(
         threshold = torch.ones(1).type_as(stats)
         threshold = self.restrict_scaling_pre(threshold)
         stats = self.restrict_scaling_pre(stats)
-        stats = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -102,10 +103,10 @@ def __init__(
             self,
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module,
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -172,13 +173,13 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
 
     def __init__(
-        self,
-        group_size: int,
-        group_dim: int,
-        input_view_impl: torch.nn.Module,
-        scaling_stats_impl: torch.nn.Module,
-        scaling_min_val: Optional[float],
-        restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
         self.group_size = group_size
         self.group_dim = group_dim
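With FloatRestrictValue() as the default, callers that previously had to pass a restriction module (or None) can omit the argument entirely. A hypothetical construction sketch; AbsMax from brevitas.core.stats and the torch.nn.Identity() view are stand-ins for illustration, not taken from this patch:

    import torch
    from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
    from brevitas.core.stats import AbsMax

    # restrict_scaling_impl is omitted and falls back to FloatRestrictValue().
    scaling_impl = RuntimeDynamicGroupStatsScaling(
        group_size=32,
        group_dim=-1,
        input_view_impl=torch.nn.Identity(),  # placeholder for a group-view module
        scaling_stats_impl=AbsMax(),
        scaling_min_val=None)

Since restrict_scaling_impl also moved behind scaling_shape in the two constructors above, any call site passing it positionally needs updating; keyword arguments avoid the reordering hazard.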
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 943f1d69c..c267da589 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -15,6 +15,7 @@
 from brevitas.core.restrict_val import _ClampValue
 from brevitas.core.restrict_val import _RestrictClampValue
 from brevitas.core.restrict_val import _RestrictValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.scaling.runtime import _StatsScaling
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _Stats
@@ -83,7 +84,7 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
@@ -163,7 +164,7 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
@@ -197,8 +198,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -214,26 +215,20 @@ def __init__(
             restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
-            self.restrict_preprocess = Identity()
-
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(
-            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(ignored)
         # Threshold division must happen after we update self.value, but before we apply restrict_preproces
         # This is because we don't want to store a parameter dependant on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = self.restrict_scaling_impl.combine_stats_threshold(
-                self.value, self.restrict_inplace_preprocess(threshold))
+            threshold = self.restrict_inplace_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
@@ -245,7 +240,7 @@ def forward(
             stats = self.restrict_inplace_preprocess(stats)
             threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value
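The reordering in ParameterFromStatsFromParameterScaling.forward keeps the learned parameter independent of the runtime threshold: the threshold is first mapped into the restricted domain by restrict_inplace_preprocess, and only then combined with self.value. A sketch of that ordering for a log2-style restriction, with illustrative values (not code from this patch):

    import torch

    # The stored parameter holds the log2-domain scale with no threshold
    # baked in; the threshold enters the same domain only at forward time.
    value = torch.tensor([3.0])        # log2(8.), threshold-independent state
    threshold = torch.tensor([2.0])    # runtime threshold, linear domain

    threshold = torch.log2(threshold)            # restrict_inplace_preprocess
    combined = value - threshold                 # combine_scale_threshold
    assert torch.exp2(combined).item() == 4.0    # 8. / 2. in the linear domain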
@@ -323,7 +318,7 @@ def __init__(
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
             scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -343,15 +338,11 @@ def __init__(
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
             False, bool)  # required to support MSE eval or variants
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
-            self.restrict_preprocess = Identity()
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         # Threshold division must happen after we update self.value, but before we apply restrict_preproces
         # This is because we don't want to store a parameter dependent on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
@@ -377,16 +368,16 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
             threshold = self.restrict_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
             threshold = self.restrict_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         if self.training:
@@ -398,7 +389,7 @@ def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None)
                 out = self.restrict_preprocess(out)
             else:
                 threshold = self.restrict_preprocess(threshold)
-                out = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+                out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
         return out
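Dropping the `restrict_scaling_impl is not None` branches is behavior-preserving because the new FloatRestrictValue() default is a no-op restriction: its init hooks reduce to Identity and its combine step is a plain division. A quick check sketch, assuming only the FloatRestrictValue API shown in restrict_val.py above:

    import torch
    from brevitas.core.restrict_val import FloatRestrictValue

    impl = FloatRestrictValue()
    x = torch.tensor([0.5])

    # The init hooks of the new default are identities, matching the removed
    # fallback that assigned Identity() when no restriction was given.
    assert torch.equal(impl.restrict_init_module()(x), x)
    assert impl.restrict_init_float(0.5) == 0.5
    assert torch.equal(impl.combine_scale_threshold(x, torch.ones(1)), x / 1.0)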