Skip to content

Commit

Permalink
End-to-end compression WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Sep 6, 2024
1 parent 8a83597 commit ac3ea02
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 68 deletions.
227 changes: 192 additions & 35 deletions nncf/openvino/quantization/compression_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, List

import numpy as np
import openvino as ov
Expand All @@ -22,7 +22,30 @@
class OVCompressionPrimitiveCache:
def __init__(self):
    """Initialize one empty cache per compression-primitive flavor.

    Each cache maps a configuration key (mode, bit-width, shapes, ...) to a
    compiled OpenVINO model, so identical primitives are compiled only once.
    """
    cache_names = (
        "_compress_weight_model_cache",
        "_compress_weight_end_to_end_model_cache",
        "_compress_decompress_weight_model_cache",
        "_compress_decompress_end_to_end_weight_model_cache",
    )
    for name in cache_names:
        setattr(self, name, {})

def get_compress_weight_primitive_end_to_end(
    self,
    config: WeightCompressionConfig,
    weight_shape: Tuple,
    reduction_axes: Optional[Tuple],
    invert_scale: Optional[bool] = False,
):
    """Return a cached end-to-end weight-compression primitive, building it on a miss.

    The "end-to-end" primitive computes quantization statistics inside the graph
    itself. Behavior toggles are read from environment variables:
    ``DYNAMIC_COMPRESSION`` makes every weight dimension dynamic (-1) and
    ``RECOMPILE`` bypasses the cache entirely, rebuilding on every call.

    :param config: Weight compression configuration (mode and bit-width).
    :param weight_shape: Shape of the weight input.
    :param reduction_axes: Axes along which statistics are reduced.
    :param invert_scale: If True, apply scale via multiplication by its reciprocal.
    """
    if bool(int(os.environ.get("DYNAMIC_COMPRESSION", "0"))):
        weight_shape = tuple(-1 for _ in weight_shape)

    if bool(int(os.environ.get("RECOMPILE", "0"))):
        return self._build_compress_model_end_to_end(config, weight_shape, reduction_axes, invert_scale)

    cache = self._compress_weight_end_to_end_model_cache
    key = (config.mode, config.num_bits, weight_shape, reduction_axes, invert_scale)
    if key not in cache:
        cache[key] = self._build_compress_model_end_to_end(config, weight_shape, reduction_axes, invert_scale)
    return cache[key]

def get_compress_weight_primitive(
self,
Expand Down Expand Up @@ -55,28 +78,97 @@ def get_compress_decompress_weight_primitive(
self,
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
reduction_axes: Optional[Tuple] = None,
scale_shape: Optional[Tuple] = None,
zero_point_shape: Optional[Tuple] = None,
invert_scale: Optional[bool] = False,
):
DYNAMIC_COMPRESSION = bool(int(os.environ.get("DYNAMIC_COMPRESSION", "0")))
if DYNAMIC_COMPRESSION:
weight_shape = (-1,) * len(weight_shape)
scale_shape = (-1,) * (len(scale_shape) - 1) + (1,)
if scale_shape is not None:
scale_shape = (-1,) * (len(scale_shape) - 1) + (1,)
if zero_point_shape is not None:
zero_point_shape = (-1,) * (len(zero_point_shape) - 1) + (1,)

recompile = bool(int(os.environ.get("RECOMPILE", "0")))
if recompile:
return self._build_compress_decompress_model(config, weight_shape, scale_shape, zero_point_shape)
key = (config.mode, config.num_bits, weight_shape, scale_shape)
return self._build_compress_decompress_model(config, weight_shape, reduction_axes, scale_shape, zero_point_shape)
key = (config.mode, config.num_bits, weight_shape, invert_scale)
if reduction_axes is not None:
key += (reduction_axes,)
if scale_shape is not None:
key += (scale_shape,)
if zero_point_shape is not None:
key += (zero_point_shape,)
if key not in self._compress_decompress_weight_model_cache:
self._compress_decompress_weight_model_cache[key] = self._build_compress_decompress_model(
config, weight_shape, scale_shape, zero_point_shape
config, weight_shape, reduction_axes, scale_shape, zero_point_shape, invert_scale
)
return self._compress_decompress_weight_model_cache[key]

@staticmethod
def _build_compress_model_end_to_end(
    config: WeightCompressionConfig,
    weight_shape: Tuple,
    reduction_axes: Optional[Tuple] = None,
    invert_scale: Optional[bool] = False,
    return_nodes: bool = False,
):
    """Build an OV compression model that derives scale/zero-point from the weight itself.

    Unlike :meth:`_build_compress_model`, no precomputed scale or zero point is
    taken as a model input: min/max statistics are reduced inside the graph over
    ``reduction_axes`` ("end-to-end" compression).

    :param config: Weight compression configuration (mode and bit-width).
    :param weight_shape: Shape of the weight parameter; may contain -1 for dynamic dims.
    :param reduction_axes: Axes reduced when computing min/max statistics.
    :param invert_scale: If True, divide by scale via multiplication with its reciprocal.
    :param return_nodes: If True, return graph nodes instead of a compiled model.
    :raises ValueError: If the INPUT_DTYPE environment variable holds an unsupported value.
    """
    # Input precision is controlled externally through the INPUT_DTYPE env variable.
    input_dtype_name = os.environ.get("INPUT_DTYPE", "fp32")
    dtype_map = {"fp32": ov.Type.f32, "fp16": ov.Type.f16, "bf16": ov.Type.bf16}
    if input_dtype_name not in dtype_map:
        # Was a bare `raise Exception`; a named error with the offending value is debuggable.
        raise ValueError(f"Unsupported INPUT_DTYPE value: {input_dtype_name}")
    input_dtype = dtype_map[input_dtype_name]

    weight = opset.parameter(weight_shape, name="w", dtype=input_dtype)
    parameters = [weight]

    mode = config.mode
    num_bits = config.num_bits
    eps = np.finfo(np.float32).eps
    if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]:
        # Asymmetric modes: scale and zero point are derived from per-group min/max.
        min_values = opset.reduce_min(weight, reduction_axes=reduction_axes, keep_dims=True)  # [a1, r, a2] -> [a1, 1, a2]
        max_values = opset.reduce_max(weight, reduction_axes=reduction_axes, keep_dims=True)  # [a1, r, a2] -> [a1, 1, a2]
        min_values, max_values = opset.convert(min_values, ov.Type.f32), opset.convert(max_values, ov.Type.f32)

        level_low = 0
        level_high = 2**num_bits - 1
        levels = level_high - level_low + 1
        scale = (max_values - min_values) / opset.constant(levels - 1, ov.Type.f32)
        # Guard against (near-)zero scales to avoid division blow-ups downstream.
        scale = opset.select(opset.abs(scale) < eps, eps, scale)

        zero_point = opset.constant(level_low, ov.Type.f32) - opset.round(min_values / scale)
        zero_point = opset.clamp(zero_point, level_low, level_high)
    else:
        # Symmetric modes: scale comes from the largest-magnitude value; no zero point.
        zero_point = None
        level_high = opset.constant(2 ** (num_bits - 1), ov.Type.f32)

        w_abs_min = opset.abs(opset.reduce_min(weight, reduction_axes=reduction_axes, keep_dims=True))
        w_max = opset.reduce_max(weight, reduction_axes=reduction_axes, keep_dims=True)
        w_abs_min, w_max = opset.convert(w_abs_min, ov.Type.f32), opset.convert(w_max, ov.Type.f32)

        # Pick whichever of |min| / max has the larger magnitude (max is negated to keep sign).
        scale = opset.select(w_abs_min >= w_max, w_abs_min, -w_max)
        scale /= level_high
        scale = opset.select(opset.abs(scale) < eps, eps, scale)

    return OVCompressionPrimitiveCache._get_compress_model(
        config,
        parameters,
        weight,
        scale,
        zero_point,
        output_only_weight=False,
        invert_scale=invert_scale,
        return_nodes=return_nodes,
    )

@staticmethod
def _build_compress_model(
config: WeightCompressionConfig,
Expand All @@ -87,15 +179,73 @@ def _build_compress_model(
return_nodes: bool = False,
):
INPUT_DTYPE = os.environ.get("INPUT_DTYPE", "fp32")
INT8_OUTPUT = bool(int(os.environ.get("INT8_OUTPUT", "0")))
SHARE_OUTPUTS = bool(int(os.environ.get("SHARE_OUTPUTS", "0")))

input_dtype = ov.Type.f32 if INPUT_DTYPE == "fp32" else ov.Type.f16 if INPUT_DTYPE == "fp16" else ov.Type.bf16
w = opset.parameter(weight_shape, name="w", dtype=input_dtype)
s = opset.parameter(scale_shape, name="s")
parameters = [w, s]
if INPUT_DTYPE == "fp32":
input_dtype = ov.Type.f32
elif INPUT_DTYPE == "fp16":
input_dtype = ov.Type.f16
elif INPUT_DTYPE == "bf16":
input_dtype = ov.Type.bf16
else:
raise Exception
weight = opset.parameter(weight_shape, name="w", dtype=input_dtype)
scale = opset.parameter(scale_shape, name="s")
parameters = [weight, scale]

zero_point = None
if config.mode in [CompressWeightsMode.INT8_ASYM, config.mode.INT4_ASYM]:
zero_point = opset.parameter(zero_point_shape, name="zp")
parameters.append(zero_point)

return OVCompressionPrimitiveCache._get_compress_model(
config,
parameters,
weight,
scale,
zero_point,
output_only_weight=True,
invert_scale=invert_scale,
return_nodes=return_nodes,
)

@staticmethod
def _build_compress_decompress_model_end_to_end(
    config: WeightCompressionConfig,
    weight_shape: Tuple,
    reduction_axes: Optional[Tuple] = None,
    invert_scale: Optional[bool] = False,
):
    """Build a compress-then-decompress model with in-graph quantization parameters.

    Delegates graph construction to :meth:`_build_compress_model_end_to_end` and
    wires its outputs into a decompression stage.
    """
    build = OVCompressionPrimitiveCache._build_compress_model_end_to_end
    # The result nodes hold the compressed weight, the scale and, possibly, a zero point.
    graph_parameters, graph_results = build(config, weight_shape, reduction_axes, invert_scale, return_nodes=True)
    return OVCompressionPrimitiveCache._get_compress_decompress_model(config, graph_parameters, graph_results)

@staticmethod
def _build_compress_decompress_model(
    config: WeightCompressionConfig,
    weight_shape: Tuple,
    reduction_axes: Optional[Tuple] = None,
    scale_shape: Optional[Tuple] = None,
    zero_point_shape: Optional[Tuple] = None,
    invert_scale: Optional[bool] = False,
):
    """Build a compress-then-decompress model whose scale/zero point are model inputs.

    :param config: Weight compression configuration (mode and bit-width).
    :param weight_shape: Shape of the weight parameter.
    :param reduction_axes: Accepted for call-site compatibility; statistics are not
        computed here because scale/zero point are supplied as inputs.
    :param scale_shape: Shape of the scale input parameter.
    :param zero_point_shape: Shape of the zero-point input (asymmetric modes only).
    :param invert_scale: If True, divide by scale via multiplication with its reciprocal.
    """
    # NOTE(review): callers (get_compress_decompress_weight_primitive) pass
    # `reduction_axes` and `invert_scale` positionally; the previous signature did not
    # accept them, which raised TypeError at runtime. The parameter order here mirrors
    # those call sites and the end-to-end sibling builder.
    parameters, results = OVCompressionPrimitiveCache._build_compress_model(
        config, weight_shape, scale_shape, zero_point_shape, invert_scale=invert_scale, return_nodes=True
    )
    # `results` holds only the compressed weight; scale/zero point enter as parameters.
    return OVCompressionPrimitiveCache._get_compress_decompress_model(config, parameters, results)

if input_dtype != ov.Type.f32:
@staticmethod
def _get_compress_model(
config: WeightCompressionConfig,
parameters: List[ov._pyopenvino.op.Parameter],
w: ov.runtime.Node,
s: ov.runtime.Node,
zp: Optional[ov.runtime.Node] = None,
output_only_weight: Optional[bool] = True,
invert_scale: Optional[bool] = None,
return_nodes: Optional[bool] = False,
):
if w.get_element_type() != ov.Type.f32:
w = opset.convert(w, ov.Type.f32)

compressed_w = w * (1 / s) if invert_scale else w / s
Expand All @@ -105,9 +255,6 @@ def _build_compress_model(
dtype = ov.Type.u8 if config.mode == CompressWeightsMode.INT8_ASYM else ov.Type.u4
level_low = 0
level_high = 2**num_bits - 1

zp = opset.parameter(zero_point_shape, name="zp")
parameters.append(zp)
compressed_w += zp
elif config.mode in [CompressWeightsMode.INT8_SYM, config.mode.INT4_SYM]:
dtype = ov.Type.i8 if config.mode == CompressWeightsMode.INT8_SYM else ov.Type.i4
Expand All @@ -116,39 +263,49 @@ def _build_compress_model(
else:
raise Exception

result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")
compressed_w = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")

INT8_OUTPUT = bool(int(os.environ.get("INT8_OUTPUT", "0")))
if INT8_OUTPUT:
result = opset.convert(result, dtype)
compressed_w = opset.convert(compressed_w, dtype)

results = [compressed_w]
if not output_only_weight:
results.append(s)
if zp is not None:
results.append(zp)
if return_nodes:
return parameters, result
return parameters, results

model = ov.Model([result], parameters)
model = ov.Model(results, parameters)

compiled_model = ov.compile_model(model, device_name="CPU")

return lambda parameters: compiled_model(parameters, share_outputs=SHARE_OUTPUTS)[0]
SHARE_OUTPUTS = bool(int(os.environ.get("SHARE_OUTPUTS", "0")))
return compiled_model, lambda parameters: compiled_model(parameters, share_outputs=SHARE_OUTPUTS)

@staticmethod
def _build_compress_decompress_model(
def _get_compress_decompress_model(
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
zero_point_shape: Optional[Tuple] = None,
parameters: List[ov._pyopenvino.op.Parameter],
results: List[ov._pyopenvino.op.Parameter]
):
parameters, clamp = OVCompressionPrimitiveCache._build_compress_model(
config, weight_shape, scale_shape, zero_point_shape, return_nodes=True
)

if len(parameters) == 3:
_, s, zp = parameters
result = (clamp - zp) * s
if config.mode in [CompressWeightsMode.INT8_ASYM, config.mode.INT4_ASYM]:
if len(results) == 1:
compressed_w = results[0]
s, zp = parameters[1], parameters[2]
else:
compressed_w, s, zp = results
decompressed_w = (compressed_w - zp) * s
else:
s = parameters[1]
result = clamp * s
if len(results) == 1:
compressed_w = results[0]
s = parameters[1]
else:
compressed_w, s = results
decompressed_w = compressed_w * s

model = ov.Model([result], parameters)
model = ov.Model([decompressed_w], parameters)
compiled_model = ov.compile_model(model, device_name="CPU")

return lambda parameters: compiled_model(parameters)[0]
Expand Down
Loading

0 comments on commit ac3ea02

Please sign in to comment.