Skip to content

Commit

Permalink
[PTQ] Reshape models to original shapes after quantization in calibrate.py (#2034)
Browse files Browse the repository at this point in the history

### Changes

Models are reshaped back to their original shapes after quantization.

### Reason for changes

To return quantized models with the same shapes as they had before
quantization.

### Related tickets

115226
  • Loading branch information
daniil-lyakhov authored Aug 8, 2023
1 parent 0b01c65 commit 41ca282
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tests/openvino/tools/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,10 @@ def maybe_reshape_model(model, dataset, subset_size, input_to_tensor_name):
model_inputs_shapes = {}
for input_output in model.inputs:
input_node = input_output.get_node()
model_inputs_shapes[input_to_tensor_name[input_node.friendly_name]] = tuple(input_node.partial_shape)
partial_shape = []
for dim in input_node.partial_shape:
partial_shape.append(Dimension(str(dim)))
model_inputs_shapes[input_to_tensor_name[input_node.friendly_name]] = tuple(partial_shape)

if len(dataset_inputs_shapes) != len(model_inputs_shapes):
raise RuntimeError(
Expand Down Expand Up @@ -720,13 +723,13 @@ def maybe_reshape_model(model, dataset, subset_size, input_to_tensor_name):
dynamic_dims[name].append(idx)

if not any(any(dict_.values()) for dict_ in [dynamic_dims, reshaped_static_dims]):
return model
return model, model_inputs_shapes

partial_shapes = {}
for name, shape in model_inputs_shapes.items():
for name, partial_shape in model_inputs_shapes.items():
dataset_first_shape = dataset_inputs_shapes[name].pop()
dims = []
for idx, d in enumerate(shape):
for idx, d in enumerate(partial_shape):
if idx in dynamic_dims[name]:
dim = Dimension(-1)
elif idx in reshaped_static_dims[name]:
Expand All @@ -741,7 +744,7 @@ def maybe_reshape_model(model, dataset, subset_size, input_to_tensor_name):
dims.append(dim)
partial_shapes[name] = PartialShape(dims)
model.reshape(partial_shapes)
return model
return model, model_inputs_shapes


# pylint: disable=protected-access
Expand Down Expand Up @@ -842,8 +845,9 @@ def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_imp
dataset = get_dataset(model_evaluator, quantization_parameters)
calibration_dataset = nncf.Dataset(dataset, transform_fn)

original_model_shapes = None
if get_allow_reshape_input(accuracy_checker_config):
ov_model = maybe_reshape_model(
ov_model, original_model_shapes = maybe_reshape_model(
ov_model,
calibration_dataset,
quantization_parameters.get("subset_size", 300),
Expand All @@ -852,6 +856,9 @@ def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_imp
model_evaluator.load_network([{"model": ov_model}])

quantized_model = nncf.quantize(ov_model, calibration_dataset, **quantization_parameters)
if original_model_shapes is not None:
quantized_model.reshape(original_model_shapes)

return quantized_model


Expand Down Expand Up @@ -895,8 +902,9 @@ def quantize_model_with_accuracy_control(
calibration_dataset = nncf.Dataset(dataset, transform_fn)
validation_dataset = ACDataset(model_evaluator, transform_fn)

original_model_shapes = None
if get_allow_reshape_input(accuracy_checker_config):
ov_model = maybe_reshape_model(
ov_model, original_model_shapes = maybe_reshape_model(
ov_model,
calibration_dataset,
quantization_parameters.get("subset_size", 300),
Expand Down Expand Up @@ -933,6 +941,8 @@ def quantize_model_with_accuracy_control(
else:
raise NotImplementedError(f"Unsupported implementation: {quantization_impl}")

if original_model_shapes is not None:
quantized_model.reshape(original_model_shapes)
return quantized_model


Expand Down

0 comments on commit 41ca282

Please sign in to comment.