
Commit

Something works
nikita-savelyevv committed Oct 21, 2024
1 parent c98880a commit 2cbbb01
Showing 7 changed files with 57 additions and 49 deletions.
11 changes: 3 additions & 8 deletions nncf/openvino/quantization/compression_primitives.py
@@ -123,6 +123,7 @@ def get_compress_decompress_weight_primitive(
zero_point_shape,
)


def _build_compress_decompress_model(
config: WeightCompressionConfig,
params: PrimitiveParameters,
@@ -131,13 +132,7 @@ def _build_compress_decompress_model(
zero_point_shape: Optional[Tuple] = None,
):
ov_parameters, ov_results = _build_compress_model(
-        config,
-        params,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        reduction_axes=None,
-        return_nodes=True
+        config, params, weight_shape, scale_shape, zero_point_shape, reduction_axes=None, return_nodes=True
)
return _get_compress_decompress_model(
config,
@@ -196,7 +191,7 @@ def _build_compress_model(

num_groups_per_channel = channel_size // group_size
shape = list(weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis
-        shape[reduction_axes: reduction_axes + 1] = (num_groups_per_channel, group_size)
+        shape[reduction_axes : reduction_axes + 1] = (num_groups_per_channel, group_size)
weight = opset.reshape(weight, shape, special_zero=False)
reduction_axes += 1

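For context, the reshape this hunk touches implements grouped quantization: a weight of shape [a1, r, a2] is split along the reduction axis into groups so each group of `group_size` channels gets its own scale. A minimal numpy sketch of the same reshape (illustrative shapes, not from the repository):

```python
import numpy as np

# Hypothetical shapes for illustration only.
weight = np.random.rand(4, 16, 8).astype(np.float32)  # [a1, r, a2]
reduction_axes, group_size = 1, 4

num_groups_per_channel = weight.shape[reduction_axes] // group_size
shape = list(weight.shape)
# Replace the reduction dimension r with (r // group_size, group_size).
shape[reduction_axes : reduction_axes + 1] = (num_groups_per_channel, group_size)
grouped = weight.reshape(shape)  # -> (4, 4, 4, 8)

# The reduction axis shifts by one; statistics are now taken per group.
scale = np.abs(grouped).max(axis=reduction_axes + 1, keepdims=True)  # (4, 4, 1, 8)
```
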
@@ -10,8 +10,16 @@
# limitations under the License.


-from .common import reshape_weight_for_grouped_quantization, calculate_nf4_scale, do_nf4_quantization, \
-    do_nf4_dequantization, calculate_normalized_weight_and_fp4_scale, calculate_integer_quantization_params, \
-    calculate_quantized_weight, compress_weight, do_int_dequantization
-
-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight
+from .common import WeightCompressionConfig
+from .common import calculate_integer_quantization_params
+from .common import calculate_nf4_scale
+from .common import calculate_normalized_weight_and_fp4_scale
+from .common import calculate_quantized_weight
+from .common import compress_weight
+from .common import do_int_dequantization
+from .common import do_nf4_dequantization
+from .common import do_nf4_quantization
+from .common import get_integer_quantization_error
+from .common import reshape_weight_for_grouped_quantization
+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
@@ -24,7 +24,6 @@

from .dispatched_functions import do_int_quantization


ReductionAxes = Tuple[int, ...]

NF4_QUANTILES = np.array(
@@ -342,7 +341,9 @@ def get_integer_quantization_error(
if weight.dtype != TensorDataType.float32:
weight = weight.astype(TensorDataType.float32)

-    compressed_weights, scale, zero_point = do_int_quantization(weight, reduction_axes, config, invert_division=invert_division)
+    compressed_weights, scale, zero_point = do_int_quantization(
+        weight, reduction_axes, config, invert_division=invert_division
+    )
decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point)

decompressed_weight = decompressed_weight.reshape(orig_shape)
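The function above measures how much accuracy is lost in the integer round trip. As a rough reference, the quantize-dequantize round trip for 8-bit asymmetric quantization looks like the following numpy sketch (the exact error metric returned by get_integer_quantization_error is outside this hunk, so the final line is just one possible measure):

```python
import numpy as np

w = np.random.rand(16, 32).astype(np.float32)

# Per-row asymmetric 8-bit quantization parameters.
w_min = w.min(axis=1, keepdims=True)
w_max = w.max(axis=1, keepdims=True)
scale = (w_max - w_min) / 255.0
zero_point = np.round(-w_min / scale)

# Quantize, then dequantize, then compare with the original.
q = np.clip(np.round(w / scale) + zero_point, 0, 255)
w_hat = (q - zero_point) * scale
error = np.abs(w - w_hat).max()  # one possible error measure
```
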
@@ -8,11 +8,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Optional
+from typing import Optional, Tuple

from nncf.tensor import Tensor
-from .weight_lowering_dispatcher import weight_lowering_dispatcher, ov_available_backend_selector, BackendParametersContainer

+from ..config import WeightCompressionConfig
+from .weight_lowering_dispatcher import ov_available_backend_selector
+from .weight_lowering_dispatcher import weight_lowering_dispatcher


@weight_lowering_dispatcher(ov_available_backend_selector)
@@ -34,8 +34,9 @@
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.tensor import Tensor

+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
from .weight_lowering_dispatcher import WeightLoweringBackend
-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight


@dataclass
@@ -46,11 +47,14 @@ class OVModelParameters:
share_outputs: bool = True
input_dtype: str = "fp32"

+    def __hash__(self):
+        return hash((self.dynamic, self.recompile, self.release_memory, self.share_outputs, self.input_dtype))


class CompiledModelCache:
def __init__(self):
self._cache = {}

def clear(self):
self._cache.clear()

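The new __hash__ matters because @dataclass generates __eq__, which implicitly sets __hash__ to None; defining it explicitly lets OVModelParameters instances participate in dictionary cache keys. A minimal sketch (field defaults other than the two shown above are assumptions):

```python
from dataclasses import dataclass

@dataclass
class OVModelParameters:
    dynamic: bool = False          # assumed default
    recompile: bool = False        # assumed default
    release_memory: bool = True    # assumed default
    share_outputs: bool = True
    input_dtype: str = "fp32"

    # Without this, `hash(OVModelParameters())` raises TypeError,
    # because @dataclass sets __hash__ = None when it generates __eq__.
    def __hash__(self):
        return hash((self.dynamic, self.recompile, self.release_memory, self.share_outputs, self.input_dtype))

cache = {}
cache[(OVModelParameters(), (8, 16))] = "compiled model"  # hashable composite key
```
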
@@ -72,7 +76,7 @@ def wrapper(*args, **kwargs):
cache = COMPILED_MODEL_CACHE._cache
if not recompile and cache_key in cache:
return cache[cache_key]
-        result = func(cache, *args, **kwargs)
+        result = func(*args, **kwargs)
cache[cache_key] = result
return result

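For reference, the surrounding decorator caches compiled models keyed on the call arguments; the fix above stops passing the cache itself into the wrapped builder, which would have shifted its positional arguments. A minimal sketch of the pattern (names and key construction are assumptions, not the repository's exact code):

```python
from functools import wraps

def cache_results(cache: dict):
    """Cache a model-building function's results keyed on its arguments."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = (func.__name__, args, tuple(sorted(kwargs.items())))
            if cache_key in cache:
                return cache[cache_key]
            # The builder is called with its own arguments only; the cache
            # is captured by closure rather than passed as a parameter.
            result = func(*args, **kwargs)
            cache[cache_key] = result
            return result
        return wrapper
    return decorator
```

This is also why OVModelParameters needs to be hashable: instances of it end up inside cache_key.
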
@@ -109,12 +113,12 @@ def _(
)

if precomputed_scale is None:
-        results = model(weight)
+        results = model(weight.data)
compressed_weight, scale, zero_point = [Tensor(it) for it in results]
else:
-        inputs = [weight, precomputed_scale]
+        inputs = [weight.data, precomputed_scale.data]
if precomputed_zero_point is not None:
-            inputs += [precomputed_zero_point]
+            inputs += [precomputed_zero_point.data]
compressed_weight = Tensor(model(inputs)[0])
scale, zero_point = precomputed_scale, precomputed_zero_point

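The `.data` changes above unwrap nncf's Tensor before calling the compiled OpenVINO model, which expects raw backing arrays; outputs are wrapped back into Tensor for the rest of the pipeline. A small sketch of the convention (the `results` line stands in for a real `model(raw)` call):

```python
import numpy as np
from nncf.tensor import Tensor

weight = Tensor(np.ones((8, 16), dtype=np.float32))

raw = weight.data                       # backing numpy array, what a compiled model accepts
results = [raw * 0.1, raw * 0.0]        # stand-in for `model(raw)` outputs
wrapped = [Tensor(r) for r in results]  # wrapped back into nncf Tensors
```
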
@@ -123,7 +127,12 @@

@calculate_quantized_dequantized_weight.register(WeightLoweringBackend.OV)
def _(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None, ov_model_params: Optional[OVModelParameters] = None, **kwargs
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    ov_model_params: Optional[OVModelParameters] = None,
+    **kwargs,
) -> Tensor:
weight_shape = weight.shape
scale_shape = scale.shape
@@ -134,17 +143,11 @@ def _(
if config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM]:
ov_model_params.dynamic = False

-    model = get_compress_decompress_weight_model(
-        config,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        ov_model_params
-    )
+    model = get_compress_decompress_weight_model(config, weight_shape, scale_shape, zero_point_shape, ov_model_params)

-    inputs = [weight, scale]
+    inputs = [weight.data, scale.data]
if zero_point is not None:
-        inputs.append(zero_point)
+        inputs.append(zero_point.data)
results = model(inputs)
decompressed_weight = [Tensor(it) for it in results][0]
return decompressed_weight
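What the compress-decompress model computes is a quantize-dequantize (fake-quantization) round trip, so callers can estimate quantized accuracy while keeping float weights. A rough numpy equivalent for a symmetric int4 configuration (the real graph is built from OpenVINO opset nodes; the scale convention here is one common choice, not necessarily NNCF's exact one):

```python
import numpy as np

def quantize_dequantize(weight, scale, level_low=-8, level_high=7):
    # Compress to integer levels, then decompress back to float.
    compressed = np.clip(np.round(weight / scale), level_low, level_high)
    return compressed * scale

w = np.random.randn(8, 16).astype(np.float32)
s = np.abs(w).max(axis=1, keepdims=True) / 7.0  # per-channel symmetric scale
w_qdq = quantize_dequantize(w, s)
```
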
@@ -218,13 +221,7 @@ def _build_compress_decompress_model(
zero_point_shape: Optional[Tuple] = None,
):
ov_parameters, ov_results = _build_compress_model(
-        config,
-        ov_model_params,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        reduction_axes=None,
-        return_nodes=True
+        config, ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes=None, return_nodes=True
)
return _get_compress_decompress_model(
config,
@@ -283,7 +280,7 @@ def _build_compress_model(

num_groups_per_channel = channel_size // group_size
shape = list(weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis
-        shape[reduction_axes: reduction_axes + 1] = (num_groups_per_channel, group_size)
+        shape[reduction_axes : reduction_axes + 1] = (num_groups_per_channel, group_size)
weight = opset.reshape(weight, shape, special_zero=False)
reduction_axes += 1

@@ -22,11 +22,14 @@
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorDataType

-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight
+from .common import calculate_integer_quantization_params
+from .common import calculate_quantized_weight
+from .common import do_int_dequantization
+from .common import reshape_weight_for_grouped_quantization
+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
from .weight_lowering_dispatcher import WeightLoweringBackend

-from .common import reshape_weight_for_grouped_quantization, calculate_quantized_weight, calculate_integer_quantization_params, do_int_dequantization

ReductionAxes = Tuple[int, ...]


@@ -90,7 +93,12 @@ def _(

@calculate_quantized_dequantized_weight.register(WeightLoweringBackend.TENSOR)
def _(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None, invert_division=False, **kwargs
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    invert_division=False,
+    **kwargs,
) -> Tensor:
compressed_weight = calculate_quantized_weight(weight, config, scale, zero_point, invert_division)
decompressed_weight = do_int_dequantization(compressed_weight, scale, zero_point)
@@ -10,12 +10,9 @@
# limitations under the License.
from enum import Enum
from functools import wraps
-from typing import Dict, Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional

from nncf.utils import is_openvino_available
-from .ov_backend import do_int_quantization as do_int_quantization_ov
-from .tensor_backend import do_int_quantization as do_int_quantization_tensor
-from functools import singledispatch


class WeightLoweringBackend(Enum):
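The dropped imports reflect the commit's central idea: do_int_quantization and calculate_quantized_dequantized_weight are no longer selected by hand-written aliases, but routed through a dispatcher that picks a registered backend implementation at call time. The diff does not show weight_lowering_dispatcher's internals, so the sketch below is an assumption reconstructed from the `.register(...)` calls visible above:

```python
from enum import Enum
from functools import wraps

class WeightLoweringBackend(Enum):
    TENSOR = "tensor"
    OV = "ov"

def ov_available_backend_selector(*args, **kwargs):
    # Prefer the OpenVINO backend when the package can be imported.
    try:
        import openvino  # noqa: F401
        return WeightLoweringBackend.OV
    except ImportError:
        return WeightLoweringBackend.TENSOR

def weight_lowering_dispatcher(backend_selector):
    def decorator(func):
        registry = {}

        @wraps(func)
        def wrapper(*args, **kwargs):
            backend = backend_selector(*args, **kwargs)
            return registry[backend](*args, **kwargs)

        def register(backend):
            def add(impl):
                registry[backend] = impl
                return impl
            return add

        wrapper.register = register
        return wrapper
    return decorator

@weight_lowering_dispatcher(ov_available_backend_selector)
def do_int_quantization(weight, *args, **kwargs):
    """Dispatched: this body is never called directly."""

@do_int_quantization.register(WeightLoweringBackend.TENSOR)
def _(weight, *args, **kwargs):
    return f"tensor-backend quantization of {weight!r}"

# The OV implementation would be registered the same way from ov_backend.py.
```
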
