From 0920f92a53229f2e40944566e16d39b4e468d9b6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 17 Sep 2024 21:19:27 +0000 Subject: [PATCH] remove function --- .../quantization/lifecycle/helpers.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/helpers.py b/src/compressed_tensors/quantization/lifecycle/helpers.py index 497a9921..9d755328 100644 --- a/src/compressed_tensors/quantization/lifecycle/helpers.py +++ b/src/compressed_tensors/quantization/lifecycle/helpers.py @@ -16,62 +16,15 @@ Miscelaneous helpers for the quantization lifecycle """ -from typing import Optional - -import torch from torch.nn import Module __all__ = [ - "update_layer_weight_quant_params", "enable_quantization", "disable_quantization", ] -def update_layer_weight_quant_params( - layer: Module, - weight: Optional[torch.Tensor] = None, - g_idx: Optional[torch.Tensor] = None, - reset_obs: bool = False, -): - """ - Update quantization parameters on layer - - :param layer: input layer - :param weight: weight to update quant params with, defaults to layer weight - :param g_idx: optional mapping from column index to group index - :param reset_obs: reset the observer before calculating quant params, - defaults to False - """ - attached_weight = getattr(layer, "weight", None) - - if weight is None: - weight = attached_weight - scale = getattr(layer, "weight_scale", None) - zero_point = getattr(layer, "weight_zero_point", None) - if g_idx is None: - g_idx = getattr(layer, "weight_g_idx", None) - observer = getattr(layer, "weight_observer", None) - - if weight is None or observer is None or scale is None or zero_point is None: - # scale, zp, or observer not calibratable or weight not available - return - - if reset_obs: - observer.reset() - - if attached_weight is not None: - weight = weight.to(attached_weight.dtype) - - updated_scale, updated_zero_point = observer(weight) - - # update scale and zero point - device = next(layer.parameters()).device - scale.data = updated_scale.to(device) - zero_point.data = updated_zero_point.to(device) - - def enable_quantization(module: Module): module.quantization_enabled = True