diff --git a/.github/workflows/base.yml.template b/.github/workflows/base.yml.template
index bf296e597..465cf1c41 100644
--- a/.github/workflows/base.yml.template
+++ b/.github/workflows/base.yml.template
@@ -20,7 +20,7 @@ jobs:
     steps:
 
     - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
     - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/base_reduced.yml.template b/.github/workflows/base_reduced.yml.template
index c50903499..46b916895 100644
--- a/.github/workflows/base_reduced.yml.template
+++ b/.github/workflows/base_reduced.yml.template
@@ -22,7 +22,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/develop_install.yml b/.github/workflows/develop_install.yml
index cea916825..bdc0df76b 100644
--- a/.github/workflows/develop_install.yml
+++ b/.github/workflows/develop_install.yml
@@ -30,7 +30,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/end_to_end.yml b/.github/workflows/end_to_end.yml
index dba8a2911..a83ba8899 100644
--- a/.github/workflows/end_to_end.yml
+++ b/.github/workflows/end_to_end.yml
@@ -32,7 +32,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/examples_llm_pytest.yml b/.github/workflows/examples_llm_pytest.yml
index e939a93b2..065a06738 100644
--- a/.github/workflows/examples_llm_pytest.yml
+++ b/.github/workflows/examples_llm_pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/examples_pytest.yml b/.github/workflows/examples_pytest.yml
index d514c6c33..262625e01 100644
--- a/.github/workflows/examples_pytest.yml
+++ b/.github/workflows/examples_pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/finn_integration.yml b/.github/workflows/finn_integration.yml
index cb1946b84..f62876a45 100644
--- a/.github/workflows/finn_integration.yml
+++ b/.github/workflows/finn_integration.yml
@@ -30,7 +30,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/notebook.yml b/.github/workflows/notebook.yml
index 7678a9d15..4e8c06e78 100644
--- a/.github/workflows/notebook.yml
+++ b/.github/workflows/notebook.yml
@@ -32,7 +32,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/ort_integration.yml b/.github/workflows/ort_integration.yml
index 519873f05..02c75fcf2 100644
--- a/.github/workflows/ort_integration.yml
+++ b/.github/workflows/ort_integration.yml
@@ -30,7 +30,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index c7e407799..4505d8df0 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_develop_install.yml b/.github/workflows/reduced_develop_install.yml
index b23fa97b6..b5ead15ee 100644
--- a/.github/workflows/reduced_develop_install.yml
+++ b/.github/workflows/reduced_develop_install.yml
@@ -32,7 +32,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_end_to_end.yml b/.github/workflows/reduced_end_to_end.yml
index f06c29872..c52fb0ceb 100644
--- a/.github/workflows/reduced_end_to_end.yml
+++ b/.github/workflows/reduced_end_to_end.yml
@@ -34,7 +34,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_examples_llm_pytest.yml b/.github/workflows/reduced_examples_llm_pytest.yml
index b9c3deffe..44b0de612 100644
--- a/.github/workflows/reduced_examples_llm_pytest.yml
+++ b/.github/workflows/reduced_examples_llm_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_examples_pytest.yml b/.github/workflows/reduced_examples_pytest.yml
index 62541236f..b4d42540c 100644
--- a/.github/workflows/reduced_examples_pytest.yml
+++ b/.github/workflows/reduced_examples_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_finn_integration.yml b/.github/workflows/reduced_finn_integration.yml
index b4e0e62d1..342a01c34 100644
--- a/.github/workflows/reduced_finn_integration.yml
+++ b/.github/workflows/reduced_finn_integration.yml
@@ -32,7 +32,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_notebook.yml b/.github/workflows/reduced_notebook.yml
index 159d4a9a4..5d2dc1b6f 100644
--- a/.github/workflows/reduced_notebook.yml
+++ b/.github/workflows/reduced_notebook.yml
@@ -34,7 +34,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_ort_integration.yml b/.github/workflows/reduced_ort_integration.yml
index 06219f128..9fcb678d2 100644
--- a/.github/workflows/reduced_ort_integration.yml
+++ b/.github/workflows/reduced_ort_integration.yml
@@ -32,7 +32,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_pytest.yml b/.github/workflows/reduced_pytest.yml
index 8af119e15..f3d0763e3 100644
--- a/.github/workflows/reduced_pytest.yml
+++ b/.github/workflows/reduced_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:
 
    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3
 
    - name: Set up Python
       uses: actions/setup-python@v2
diff --git a/src/brevitas/config.py b/src/brevitas/config.py
index a5685721c..082a6508b 100644
--- a/src/brevitas/config.py
+++ b/src/brevitas/config.py
@@ -25,3 +25,4 @@ def env_to_bool(name, default):
 _FULL_STATE_DICT = False
 _IS_INSIDE_QUANT_LAYER = None
 _ONGOING_EXPORT = None
+_RETROCOMPATIBLE_SCALING = False
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index f4fd79f1a..145f5ca06 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -67,12 +67,13 @@ def __init__(
 
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-            scale = scale / float_scaling_impl_value
+        else:
+            float_scaling_impl_value = None
+        scale = self.scaling_impl(x, float_scaling_impl_value)
         x = self.input_view_impl(x)
         scaled_x = x / scale
         internal_scale = float_internal_scale(
diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index cdb75df74..a9bcf6e6b 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -8,6 +8,7 @@ from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper.misc import Identity
 from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste
@@ -138,20 +139,24 @@ def __init__(
             scaling_impl: Module,
             int_scaling_impl: Module,
             zero_point_impl: Module,
-            bit_width_impl: Module):
+            bit_width_impl: Module,
+            scaling_int_quant: Optional[Module] = None):
         super(RescalingIntQuant, self).__init__()
         self.int_quant = int_quant
         self.scaling_impl = scaling_impl
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        if scaling_int_quant is None:
+            self.scaling_int_quant = Identity()
+        else:
+            self.scaling_int_quant = scaling_int_quant
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         bit_width = self.msb_clamp_bit_width_impl()
-        threshold = self.scaling_impl(x)
         int_threshold = self.int_scaling_impl(bit_width)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
@@ -184,8 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
         pre_threshold = self.pre_scaling_impl(x)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
@@ -250,8 +254,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 449318765..97a957fc6 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -90,6 +90,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
+    def retrocompatibility_op(self, x):
+        return x
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> Tensor:
         return x
@@ -113,6 +116,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
+    def retrocompatibility_op(self, x):
+        return self.power_of_two(x)
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.power_of_two(x)
@@ -137,6 +143,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
+    def retrocompatibility_op(self, x):
+        return x
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
@@ -162,8 +171,38 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
+    def retrocompatibility_op(self, x):
+        return self.power_of_two(x)
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
+
+
+class QuantRestrictValue(brevitas.jit.ScriptModule):
+
+    def __init__(self, restrict_value_float_to_int_impl: Module):
+        super(QuantRestrictValue, self).__init__()
+        self.float_to_int_impl = restrict_value_float_to_int_impl
+
+    def restrict_init_float(self, x: float):
+        return Identity()
+
+    def restrict_init_tensor(self, x: torch.Tensor):
+        return Identity()
+
+    def restrict_init_module(self):
+        return Identity()
+
+    def restrict_init_inplace_module(self):
+        return Identity()
+
+    def retrocompatibility_op(self, x):
+        return Identity()
+
+    @brevitas.jit.script_method
+    def forward(self, x: torch.Tensor):
+        o, *_ = self.float_to_int_impl(x)
+        return o
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index e4333186d..0a49b3c70 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -50,10 +50,11 @@ def __init__(
             dtype,
             device)
 
-    @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
-        stats = self.parameter_list_stats()
-        return self.stats_scaling_impl(stats)
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        stats = self.parameter_list_stats(x)
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _StatsScaling(brevitas.jit.ScriptModule):
@@ -80,8 +81,11 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor) -> torch.Tensor:
-        stats = self.restrict_scaling_pre(stats)
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        stats = self.restrict_scaling_pre(stats / threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -120,9 +124,9 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -179,9 +183,14 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, stats_input) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped)
+        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 53f389331..bdf838749 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -77,8 +77,10 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = self.value()
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        value = self.value() / threshold
         restricted_value = self.restrict_clamp_scaling(value)
         return restricted_value
 
@@ -149,8 +151,10 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
         return value
 
     def _load_from_state_dict(
@@ -190,29 +194,39 @@ def __init__(
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
             tracked_parameter_list)
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
         if restrict_scaling_impl is not None:
             self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         else:
             self.restrict_inplace_preprocess = Identity()
+            self.restrict_preprocess = Identity()
+
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(x)
+        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
+        # This is because we don't want to store a parameter dependant on a runtime value (threshold)
+        # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_preprocess(self.value / threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats)
-            stats = self.restrict_inplace_preprocess(stats)
+                return self.stats_scaling_impl(stats, threshold)
             inplace_tensor_mul(self.value.detach(), stats)
+            value = self.restrict_preprocess(self.value / threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
             self.init_done = True
             return value
@@ -228,9 +242,18 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
+        value_key = prefix + 'value'
+
+        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
+        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
+        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
+        if config._RETROCOMPATIBLE_SCALING:
+            if not isinstance(self.restrict_scaling_impl, Identity):
+                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
+                    state_dict[value_key])
+
         super(ParameterFromStatsFromParameterScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-        value_key = prefix + 'value'
         # disable stats collection when a pretrained value is loaded
         if value_key not in missing_keys:
             self.init_done = True
@@ -305,6 +328,7 @@ def __init__(
             scaling_stats_momentum, Optional[float])
         self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
@@ -317,7 +341,10 @@ def __init__(
             self.restrict_preprocess = Identity()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
+        # This is because we don't want to store a parameter dependant on a runtime value (threshold)
+        # And because restrict needs to happen after we divide by threshold
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -327,32 +354,37 @@ def training_forward(self, stats_input: Tensor) -> Tensor:
             new_counter = self.counter + 1
             # Whenever we are in local loss mode, we don't update the counter nor the buffer
             if self.local_loss_mode:
-                return abs_binary_sign_grad(clamped_stats)
+                # Local loss mode, we early exit and divide by threshold
+                return abs_binary_sign_grad(clamped_stats / threshold)
             if self.counter == 0:
                 inplace_tensor_mul(self.buffer, clamped_stats.detach())
             else:
                 inplace_momentum_update(
                     self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter)
             self.counter = new_counter
-            return abs_binary_sign_grad(clamped_stats)
+            return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
-            self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
+            value = self.restrict_preprocess(self.value / threshold)
             self.counter = self.counter + 1
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            value = self.restrict_preprocess(self.value / threshold)
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         if self.training:
-            return self.training_forward(stats_input)
+            # Threshold division handled inside the training_forward
+            return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
-                out = self.buffer
+                out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.value
+                out = self.restrict_preprocess(self.value / threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
         return out
@@ -378,6 +410,14 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)
 
+        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
+        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
+        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
+        if config._RETROCOMPATIBLE_SCALING:
+            if not isinstance(self.restrict_scaling_impl, Identity):
+                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
+                    state_dict[value_key])
+
         super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         # Buffer is supposed to be always missing
diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py
index df3cec952..49bf62a82 100644
--- a/src/brevitas/core/stats/stats_wrapper.py
+++ b/src/brevitas/core/stats/stats_wrapper.py
@@ -13,6 +13,7 @@ from brevitas.core.utils import inplace_tensor_mul
 
 from .view_wrapper import _ViewCatParameterWrapper
+from .view_wrapper import _ViewParameter
 from .view_wrapper import _ViewParameterWrapper
 
 DEFAULT_MOMENTUM = 0.1
@@ -96,8 +97,12 @@ def __init__(
         super(_ParameterListStats, self).__init__()
         self.stats_input_concat_dim = stats_input_concat_dim
-        self.first_tracked_param = _ViewParameterWrapper(
-            tracked_parameter_list[0], stats_input_view_shape_impl)
+        if len(tracked_parameter_list) >= 1:
+            self.first_tracked_param = _ViewParameterWrapper(
+                tracked_parameter_list[0], stats_input_view_shape_impl)
+        else:
+            self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl)
+
         if len(tracked_parameter_list) > 1:
             extra_list = [
                 _ViewCatParameterWrapper(
@@ -109,10 +114,12 @@ def __init__(
         self.stats = _Stats(stats_impl, stats_output_shape)
 
     @brevitas.jit.script_method
-    def forward(self) -> torch.Tensor:
-        stats_input = self.first_tracked_param()
+    def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.extra_tracked_params_list is not None:
+            stats_input = self.first_tracked_param(None)
             for extra_tracked_param in self.extra_tracked_params_list:
                 stats_input = extra_tracked_param(stats_input)
+        else:
+            stats_input = self.first_tracked_param(x)
         out = self.stats(stats_input)
         return out
diff --git a/src/brevitas/core/stats/view_wrapper.py b/src/brevitas/core/stats/view_wrapper.py
index acea542d9..98c6ab538 100644
--- a/src/brevitas/core/stats/view_wrapper.py
+++ b/src/brevitas/core/stats/view_wrapper.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -19,8 +21,12 @@ def __init__(self, parameter: Parameter, view_shape_impl: Module) -> None:
         self.view_shape_impl = view_shape_impl
 
     @brevitas.jit.script_method
-    def forward(self) -> Tensor:
-        return self.view_shape_impl(self.parameter)
+    def forward(self, x: Optional[Tensor]) -> Tensor:
+        if x is not None:
+            parameter = x
+        else:
+            parameter = self.parameter
+        return self.view_shape_impl(parameter)
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -39,6 +45,17 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return output_dict
 
 
+class _ViewParameter(brevitas.jit.ScriptModule):
+
+    def __init__(self, view_shape_impl: Module) -> None:
+        super(_ViewParameter, self).__init__()
+        self.view_shape_impl = view_shape_impl
+
+    @brevitas.jit.script_method
+    def forward(self, x: Tensor) -> Tensor:
+        return self.view_shape_impl(x)
+
+
 class _ViewCatParameterWrapper(brevitas.jit.ScriptModule):
     __constants__ = ['cat_dim']
diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py
index 3f80f1dd4..499de376f 100644
--- a/src/brevitas/core/zero_point.py
+++ b/src/brevitas/core/zero_point.py
@@ -60,6 +60,20 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tenso
         return out
 
 
+class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule):
+    __constants__ = ['quantize_zero_point']
+
+    def __init__(self, zp_int_quant: Module, quantize_zero_point: bool) -> None:
+        super(_ScaleShiftQuantZeroPoint, self).__init__()
+        self.zp_int_quant = zp_int_quant
+        self.quantize_zero_point = quantize_zero_point
+
+    @brevitas.jit.script_method
+    def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
+        quant_zp, scale, *_ = self.zp_int_quant(zero_point)
+        return quant_zp
+
+
 class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule):
 
     def __init__(
@@ -70,7 +84,8 @@ def __init__(
             zero_point_stats_input_concat_dim: int,
             zero_point_stats_impl: Module,
             zero_point_shape: Tuple[int, ...],
-            tracked_parameter_list: List[torch.nn.Parameter]) -> None:
+            tracked_parameter_list: List[torch.nn.Parameter],
+            scale_shit_zero_point_impl: Optional[Module] = None) -> None:
         super(StatsFromParameterZeroPoint, self).__init__()
         self.parameter_list_stats = _ParameterListStats(
             zero_point_stats_impl,
@@ -78,11 +93,14 @@ def __init__(
             zero_point_stats_input_view_shape_impl,
             zero_point_stats_input_concat_dim,
             tracked_parameter_list)
-        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        if scale_shit_zero_point_impl is None:
+            self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        else:
+            self.scale_shift_zero_point = scale_shit_zero_point_impl
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
-        stats = self.parameter_list_stats()
+        stats = self.parameter_list_stats(x)
         return self.scale_shift_zero_point(-stats, scale, bit_width)
@@ -266,7 +284,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
             value = self.scale_shift_zero_point(value, scale, bit_width)
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py
index 93cc235e2..776f1f6b2 100644
--- a/src/brevitas_examples/common/generative/quant_blocks.py
+++ b/src/brevitas_examples/common/generative/quant_blocks.py
@@ -25,10 +25,10 @@ def __init__(
         self.stats_impl = scaling_stats_impl
         self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
 
-    def forward(self, x) -> Tensor:
+    def forward(self, x, threshold) -> Tensor:
         shape = x.shape
         x = self.scaling_stats_input_view_shape_impl(x)
-        x = self.stats_impl(x)
+        x = self.stats_impl(x) / threshold
         x = self.dynamic_scaling_broadcastable_fn(x, shape)
         return x
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e19390774..250b7755f 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -268,6 +268,8 @@ def main(args):
 
     model = offload_model(model)
 
+    model(**calibration_loader[0])
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 16b8a4b5f..a5f597586 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -109,9 +109,10 @@ def test_float_to_quant_float(inp, minifloat_format):
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
+    float_scaling_impl_return = 1.
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
-    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: float_scaling_impl_return)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
@@ -148,7 +149,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
             torch.tensor(exponent_bit_width),
             torch.tensor(mantissa_bit_width),
             torch.tensor(exponent_bias))
-        scaling_impl.assert_called_once_with(inp)
+        scaling_impl.assert_called_once_with(inp, float_scaling_impl_return)
 
 
 @given(
@@ -160,7 +161,7 @@ def test_inner_scale(inp, minifloat_format, scale):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
     float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
-    scaling_impl = mock.Mock(side_effect=lambda x: scale)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: scale)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py
new file mode 100644
index 000000000..dab8312e9
--- /dev/null
+++ b/tests/brevitas/core/test_scaling_quant.py
@@ -0,0 +1,127 @@
+from dependencies import this
+from dependencies import value
+import torch
+
+from brevitas.core.quant.int import RescalingIntQuant
+from brevitas.core.restrict_val import QuantRestrictValue
+from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
+from brevitas.inject.enum import ScalingPerOutputType
+import brevitas.nn as qnn
+from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
+from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
+
+
+class QuantScalingInt(Int8WeightPerTensorFloat):
+    bit_width = 8
+    module = (this << 1).module
+    tracked_parameter_list = (this << 1).tracked_parameter_list
+    upstream_scaling = (this << 1).scaling_per_output_type
+    rescaling_int_quant = RescalingIntQuant
+
+    @value
+    def scaling_shape(
+            scaling_per_output,
+            scaling_per_output_channel_shape,
+            expanded_groupwise_shape,
+            group_dim,
+            upstream_scaling):
+        if scaling_per_output == ScalingPerOutputType.TENSOR:
+            scaling = SCALAR_SHAPE
+        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
+            scaling = scaling_per_output_channel_shape
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
+            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
+            assert group_dim is not None, "Per Group scaling not correctly configured"
+            size = list(expanded_groupwise_shape)
+            size[group_dim + 1] = 1
+            scaling = tuple(size)
+
+        # When quantizing scale of groupwise, there will be one extra dim compared to the normal case
+        if upstream_scaling == ScalingPerOutputType.GROUP:
+            scaling = list(scaling)
+            scaling.insert(-1, 1)
+            scaling = tuple(scaling)
+        return scaling
+
+
+from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint
+
+
+class QuantZPInt(Int8WeightPerTensorFloat):
+    bit_width = 8
+    module = (this << 1).module
+    tracked_parameter_list = (this << 1).tracked_parameter_list
+    upstream_scaling = (this << 1).scaling_per_output_type
+    rescaling_int_quant = RescalingIntQuant
+    bit_width = 6
+    quantize_zero_point = True
+    scaling_per_output_type = ScalingPerOutputType.CHANNEL
+
+    @value
+    def scaling_shape(
+            scaling_per_output,
+            scaling_per_output_channel_shape,
+            expanded_groupwise_shape,
+            group_dim,
+            upstream_scaling):
+        if scaling_per_output == ScalingPerOutputType.TENSOR:
+            scaling = SCALAR_SHAPE
+        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
+            scaling = scaling_per_output_channel_shape
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
+            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
+            assert group_dim is not None, "Per Group scaling not correctly configured"
+            size = list(expanded_groupwise_shape)
+            size[group_dim + 1] = 1
+            scaling = tuple(size)
+
+        # When quantizing scale of groupwise, there will be one extra dim compared to the normal case
+        if upstream_scaling == ScalingPerOutputType.GROUP:
+            scaling = list(scaling)
+            scaling.insert(-1, 1)
+            scaling = tuple(scaling)
+        return scaling
+
+
+class QuantScaleInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
+    proxy_class = GroupwiseWeightQuantProxyFromInjector
+    scaling_int_quant = QuantScalingInt
+    zp_int = QuantZPInt
+    restrict_scaling_impl = QuantRestrictValue
+    scaling_per_output_type = ScalingPerOutputType.GROUP
+    scale_shit_zero_point_impl = _ScaleShiftQuantZeroPoint
+    group_size = 32
+
+    @value
+    def restrict_value_float_to_int_impl():
+        return this.scaling_int_quant.rescaling_int_quant
+
+    @value
+    def zp_int_quant():
+        return this.zp_int.rescaling_int_quant
+
+
+def test_quant_scale():
+
+    def hook_scale(module, inp):
+        inp = inp[0]
+        quant_scale, scale, *_ = module.float_to_int_impl(inp)
+        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))
+
+    def hook_zp(module, inp):
+        inp = inp[0]
+        quant_scale, scale, *_ = module.zp_int_quant(inp)
+        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))
+
+    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleInt8WeightPerTensorFloat)
+    for module in linear.modules():
+        if isinstance(module, QuantRestrictValue):
+            module.register_forward_pre_hook(hook_scale)
+    for module in linear.modules():
+        if isinstance(module, _ScaleShiftQuantZeroPoint):
+            module.register_forward_pre_hook(hook_zp)
+
+    linear(torch.randn(1, 64))
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index b22994275..10d8f7e7c 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,7 @@ from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -27,7 +28,9 @@ BATCH = 1
 REFERENCE_SCALES = {
     'int_quant': (0.00935234408825635910, 0.01362917013466358185),
-    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183),
+    'int_po2_quant': (0.015625, 0.015625),
+    'fp_po2_quant': (0.001953125, 0.00390625),}
 REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], [1.4573, -0.9074, -0.2708]])
@@ -44,9 +47,9 @@ def reference_implementation_scale_factors_po2(
     quant = compute_quantile(x, q)
     quant = torch.max(min_val, quant)
     quant_float_to_int = torch.ceil(
-        torch.log2(quant))  # Float to Int Implementation for PowerOfTwo scale
+        torch.log2(quant / int_scale))  # Float to Int Implementation for PowerOfTwo scale
 
-    scale = torch.pow(torch.tensor(2.), quant_float_to_int) / int_scale
+    scale = torch.pow(torch.tensor(2.), quant_float_to_int)
     return scale
@@ -75,7 +78,15 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
-QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
+    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
+
+
+QUANTS = {
+    'int_quant': Int8ActPerTensorFloat,
+    'fp_quant': Fp8e4m3ActPerTensorFloat,
+    'int_po2_quant': Int8ActPerTensorFixedPoint,
+    'fp_po2_quant': Fp8e4m3ActPerTensorFixedPoint}
 
 
 @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
diff --git a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
index dded0e276..c72a2d1f8 100644
--- a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
+++ b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
@@ -12,9 +12,11 @@ from qonnx.transformation.infer_shapes import InferShapes
 import torch
 
+import brevitas.config as config
 from brevitas.export import export_qonnx
 from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b
 
+config._RETROCOMPATIBLE_SCALING = True
 QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256)  # B, features, sequence
 MIN_INP_VAL = 0.0
 MAX_INP_VAL = 200.0
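
A minimal usage sketch (not part of the patch above) of the new config._RETROCOMPATIBLE_SCALING flag, mirroring how the quartznet test enables it before loading a pretrained model; the quantizer choice and checkpoint path are illustrative placeholders, not taken from the diff.

# Sketch only: "pre_refactor_act_quant.pth" is a placeholder checkpoint path.
import torch

import brevitas.config as config
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint

act = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint)

# Checkpoints created before this change stored the learned scale after
# restrict_preprocess (e.g. in log2 form). With the flag enabled,
# _load_from_state_dict applies retrocompatibility_op (e.g. power-of-two),
# so the forward pass, which now re-applies restrict_preprocess, still sees
# the value it expects.
config._RETROCOMPATIBLE_SCALING = True
old_state = torch.load("pre_refactor_act_quant.pth", map_location="cpu")  # placeholder
act.load_state_dict(old_state)
config._RETROCOMPATIBLE_SCALING = False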
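
The recurring change in this patch moves the threshold division inside the scaling implementations: callers such as RescalingIntQuant.forward now pass the integer threshold as a second argument instead of dividing the returned threshold themselves. A small sketch of the new calling convention (not part of the patch), assuming ConstScaling with its default float restriction and no minimum scale:

# Sketch only: demonstrates the new scaling_impl(x, threshold) signature.
import torch

from brevitas.core.scaling.standalone import ConstScaling

scaling_impl = ConstScaling(0.5)
x = torch.randn(4)
int_threshold = torch.tensor(127.)  # e.g. int_scaling_impl(bit_width) for signed 8 bit

# Before: scale = scaling_impl(x) / int_threshold
# After:  the division by the threshold happens inside the scaling implementation
scale = scaling_impl(x, int_threshold)
assert torch.allclose(scale, torch.tensor(0.5) / int_threshold)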