Implement astype for ov backend for bf16, u4, i4
nikita-savelyevv committed Oct 25, 2024
1 parent dee911a · commit 851ad7f
Showing 6 changed files with 69 additions and 42 deletions.
@@ -326,17 +326,13 @@ def transform_model(
        mul_output = mul.output(0)
        for target_input in const_node.output(0).get_target_inputs():
            target_input.replace_source_output(mul_output)

        # if compressed_weight.tensor.backend == TensorBackend.ov:
        #     if compressed_weight.tensor.dtype == TensorDataType.uint4:
        #         compressed_weight.tensor = compressed_weight.tensor.astype(TensorDataType.uint8)
        #     compressed_weight.tensor = compressed_weight.tensor.to_backend(TensorBackend.numpy)
        if lora_correction_algo is not None and lora_correction_algo.is_applicable(wc_params):
            if weight.backend == TensorBackend.ov:
                if weight.dtype == TensorDataType.bfloat16:
                    weight = weight.astype(TensorDataType.float32)
                weight = weight.to_backend(TensorBackend.numpy)
            # TODO: cast int4 ov tensor too?
            if compressed_weight.tensor.backend == TensorBackend.ov:
                compressed_weight.tensor = compressed_weight.tensor.to_backend(TensorBackend.numpy)
            if compressed_weight.zero_point.backend == TensorBackend.ov:
                compressed_weight.zero_point = compressed_weight.zero_point.to_backend(TensorBackend.numpy)
            adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params)
            self.insert_adapters(wc_params, *adapters, int8_lora=lora_correction_algo.use_int8_adapters)

@@ -23,6 +23,7 @@
from nncf.results_caching import cache_results
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from nncf.tensor.functions.ov import DTYPE_MAP as OV_DTYPE_MAP

TensorList = List[Tensor]
ModelCallable = Callable[[TensorList], TensorList]
@@ -57,7 +58,8 @@ def __hash__(self):

def run_model(ov_model_params: OVModelParameters, compiled_model: ov.CompiledModel, inputs: TensorList) -> TensorList:
    # Returns results as numpy tensors
    inputs = [inp.data for inp in inputs]
    if any(isinstance(it, Tensor) for it in inputs):
        inputs = [inp.data for inp in inputs]
    outputs = compiled_model(
        inputs, share_inputs=ov_model_params.share_inputs, share_outputs=ov_model_params.share_outputs
    )
@@ -71,7 +73,8 @@ def run_model_via_infer_request(
    ov_model_params: OVModelParameters, compiled_model: ov.CompiledModel, inputs: TensorList
) -> TensorList:
    # Returns results as ov tensors
    inputs = [inp.data for inp in inputs]
    if any(isinstance(it, Tensor) for it in inputs):
        inputs = [inp.data for inp in inputs]
    infer_request = compiled_model.create_infer_request()
    infer_request.infer(inputs, share_inputs=ov_model_params.share_inputs, share_outputs=ov_model_params.share_outputs)
    outputs = [Tensor(infer_request.get_output_tensor(i)) for i in range(len(infer_request.results))]
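The builders below hand these runners out with the compiled model and parameters already bound (see the `return partial(run_fn, ...)` lines further down). A minimal usage sketch — `ov_model_params`, `compiled_model`, and `weight_tensor` are stand-ins assumed to be in hand:

from functools import partial

# Bind the runner once; callers then just pass input tensors.
model_run_fn = partial(run_model_via_infer_request, ov_model_params, compiled_model)
outputs = model_run_fn([weight_tensor])  # accepts nncf Tensors or raw backend tensors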
@@ -100,7 +103,8 @@ def get_compress_weight_model(
    if zero_point_shape is not None:
        zero_point_shape = (-1,) * (len(zero_point_shape) - 1) + (1,)

    ov_model_params.return_ov_tensors = config.num_bits == 4
    if config.num_bits == 4:
        ov_model_params.return_ov_tensors = True

    return _build_compress_model(
        config,
@@ -147,15 +151,7 @@ def _build_compress_model(
    reduction_axes: Optional[Tuple] = None,
    return_nodes: bool = False,
) -> ModelCallable:
    if ov_model_params.input_dtype == TensorDataType.float32:
        input_dtype = ov.Type.f32
    elif ov_model_params.input_dtype == TensorDataType.float16:
        input_dtype = ov.Type.f16
    elif ov_model_params.input_dtype == TensorDataType.bfloat16:
        input_dtype = ov.Type.bf16
    else:
        raise Exception
    weight = opset.parameter(weight_shape, name="w", dtype=input_dtype)
    weight = opset.parameter(weight_shape, name="w", dtype=OV_DTYPE_MAP[ov_model_params.input_dtype])
    ov_parameters = [weight]

    if scale_shape is not None:
@@ -212,7 +208,7 @@ def _build_compress_model(
    if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]:
        dtype = ov.Type.u8 if config.mode == CompressWeightsMode.INT8_ASYM else ov.Type.u4
        level_low = 0
        level_high = 2 ** num_bits - 1
        level_high = 2**num_bits - 1
        compressed_w += zero_point
    elif config.mode in [CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT4_SYM]:
        dtype = ov.Type.i8 if config.mode == CompressWeightsMode.INT8_SYM else ov.Type.i4
@@ -272,3 +268,20 @@ def _build_compress_decompress_model(

    run_fn = run_model_via_infer_request if ov_model_params.return_ov_tensors else run_model
    return partial(run_fn, ov_model_params, compiled_model)


def get_astype_model(ov_model_params: OVModelParameters, arg_shape: Tuple, dtype: TensorDataType) -> ModelCallable:
    if ov_model_params.dynamic_shapes:
        arg_shape = (-1,) * len(arg_shape)
    return _build_astype_model(ov_model_params, arg_shape, dtype)


@cache_results(OV_MODEL_CACHE)
def _build_astype_model(ov_model_params: OVModelParameters, arg_shape: Tuple, dtype: TensorDataType) -> ModelCallable:
    arg = opset.parameter(arg_shape, dtype=OV_DTYPE_MAP[ov_model_params.input_dtype])
    res = opset.convert(arg, OV_DTYPE_MAP[dtype])
    model = ov.Model([res], [arg])
    compiled_model = ov.compile_model(model, device_name="CPU")

    run_fn = run_model_via_infer_request if ov_model_params.return_ov_tensors else run_model
    return partial(run_fn, ov_model_params, compiled_model)
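Usage sketch for the new astype model, mirroring the caller added in nncf/tensor/functions/ov.py below — `bf16_tensor` is a stand-in for an ov.Tensor of element type bf16 assumed to be in hand:

params = OVModelParameters(
    input_dtype=TensorDataType.bfloat16,
    dynamic_shapes=True,
    recompile=False,
    release_memory=True,
    share_inputs=True,
    share_outputs=True,
    return_ov_tensors=True,
)
astype_fn = get_astype_model(params, bf16_tensor.shape, TensorDataType.float32)
fp32_ov_tensor = astype_fn([bf16_tensor])[0].data  # unwrap the nncf Tensor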
@@ -378,8 +378,6 @@ def compress_weight(
"""
if not config.is_integer():
if weight.backend == TensorBackend.ov:
if weight.dtype == TensorDataType.bfloat16:
weight = weight.astype(TensorDataType.float32)
weight = weight.to_backend(TensorBackend.numpy)

compressed_weight, scale = calculate_normalized_weight_and_fp4_scale(
@@ -461,8 +459,6 @@ def do_int_quantization(
    # Reference implementation

    if weight.backend == TensorBackend.ov:
        if weight.dtype == TensorDataType.bfloat16:
            weight = weight.astype(TensorDataType.float32)
        weight = weight.to_backend(TensorBackend.numpy)

    if weight.dtype != TensorDataType.float32:
1 change: 1 addition & 0 deletions nncf/results_caching.py
@@ -36,4 +36,5 @@ def wrapper(*args, disable_caching=False, **kwargs):
            return result

        return wrapper

    return decorator
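For context, a minimal self-contained sketch of the caching pattern this decorator implements — the dict-based cache and key construction here are assumptions, not the actual nncf.results_caching internals:

from functools import wraps

def cache_results(cache: dict):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, disable_caching=False, **kwargs):
            if disable_caching:
                return func(*args, **kwargs)
            key = (func.__name__, args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator

MODEL_CACHE: dict = {}

@cache_results(MODEL_CACHE)
def build_model(shape):
    return {"shape": shape}  # built only on a cache miss

build_model((2, 3))                         # builds and caches
build_model((2, 3))                         # served from the cache
build_model((2, 3), disable_caching=True)   # rebuilds, bypassing the cache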
2 changes: 2 additions & 0 deletions nncf/tensor/definitions.py
@@ -36,6 +36,8 @@ class TensorDataType(Enum):
    int32 = auto()
    int64 = auto()
    uint8 = auto()
    uint4 = auto()
    int4 = auto()

    def is_float(self):
        """
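A quick check of the new members — assuming is_float() reports True only for the floating-point dtypes, as its name suggests:

from nncf.tensor.definitions import TensorDataType

assert TensorDataType.uint4.is_float() is False
assert TensorDataType.bfloat16.is_float() is True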
53 changes: 36 additions & 17 deletions nncf/tensor/functions/ov.py
@@ -28,26 +28,34 @@
    TensorDataType.int32: ov.Type.i32,
    TensorDataType.int64: ov.Type.i64,
    TensorDataType.uint8: ov.Type.u8,
    TensorDataType.uint4: ov.Type.u4,
    TensorDataType.int4: ov.Type.i4,
}

DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
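Round-trip sanity check for the new low-precision entries, evaluated in this module's namespace (assuming an OpenVINO build that exposes u4/i4, as 2024.x releases do):

assert DTYPE_MAP[TensorDataType.uint4] == ov.Type.u4
assert DTYPE_MAP_REV[ov.Type.i4] == TensorDataType.int4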


def _bf16_to_fp32(a: ov.Tensor) -> ov.Tensor:
    assert a.get_element_type() == ov.Type.bf16 and a.data.dtype == np.float16

    a = a.data.view(np.uint16)

    res = a.astype(np.uint32)
    res = (
        ((res & 0x8000) << 16)  # Move sign bit to bit 31
        | ((res & 0x7F80) << 16)  # Move exponent to bits 30-23
        | ((res & 0x007F) << 16)  # Move fraction to bits 22-0
    )
    res = res.view(np.float32)

    res = ov.Tensor(res)
    return res
def _ov_astype(a: ov.Tensor, dtype: TensorDataType) -> ov.Tensor:
    from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters
    from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_astype_model

    a_dtype = DTYPE_MAP_REV[a.get_element_type()]
    assert a_dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]

    model = get_astype_model(
        OVModelParameters(
            input_dtype=a_dtype,
            dynamic_shapes=True,
            recompile=False,
            release_memory=True,
            share_inputs=True,
            share_outputs=True,
            return_ov_tensors=True,
        ),
        a.shape,
        dtype,
    )
    return model([a])[0].data
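A worked check of the bit trick in the removed _bf16_to_fp32 helper: bf16 is fp32 truncated to its top 16 bits (1 sign, 8 exponent, 7 fraction), so widening to uint32 and shifting left by 16 reconstructs the exact fp32 value:

import numpy as np

bf16_bits = np.array([0x3FC0], dtype=np.uint16)              # bf16 pattern for 1.5
fp32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # bits 0x3FC00000
assert fp32[0] == 1.5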


@numeric.backend.register(ov.Tensor)
@@ -57,10 +65,10 @@ def _(a: ov.Tensor) -> TensorBackend:

@numeric.astype.register(ov.Tensor)
def _(a: ov.Tensor, dtype: TensorDataType) -> ov.Tensor:
    if dtype == TensorDataType.bfloat16:
        raise ValueError("Not supported conversion")
    if a.get_element_type() == ov.Type.bf16:
        a = _bf16_to_fp32(a)
    a_dtype = DTYPE_MAP_REV[a.get_element_type()]
    if a_dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]:
        return _ov_astype(a, dtype)

    return ov.Tensor(a.data.astype(NP_DTYPE_MAP[dtype]))


@@ -83,4 +91,15 @@ def _(a: ov.Tensor, shape: Union[int, Tuple[int, ...]]) -> ov.Tensor:
def _(a: ov.Tensor, b: TensorBackend) -> np.ndarray:
    if b != TensorBackend.numpy:
        raise ValueError("Not supported backend")

    # Cannot convert bfloat16, uint4, int4 to numpy directly
    a_dtype = DTYPE_MAP_REV[a.get_element_type()]
    if a_dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]:
        dtype = TensorDataType.float32
        if a_dtype == TensorDataType.uint4:
            dtype = TensorDataType.uint8
        elif a_dtype == TensorDataType.int4:
            dtype = TensorDataType.int8
        a = _ov_astype(a, dtype)

    return a.data
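Usage sketch from the nncf Tensor side — `u4_ov_tensor` is a stand-in for an ov.Tensor of element type u4 (e.g. a compressed weight obtained with return_ov_tensors=True), assumed to be in hand:

from nncf.tensor import Tensor
from nncf.tensor.definitions import TensorBackend, TensorDataType

t = Tensor(u4_ov_tensor)
t_np = t.to_backend(TensorBackend.numpy)    # routed through _ov_astype above
assert t_np.dtype == TensorDataType.uint8   # u4 widens to uint8 on the numpy side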
