Commit d45a55f
Yolo v8 quantization
daniil-lyakhov committed Jun 17, 2024
1 parent 9eaa0a4 commit d45a55f
Showing 5 changed files with 162 additions and 15 deletions.
137 changes: 131 additions & 6 deletions examples/post_training_quantization/openvino/yolov8/main.py
@@ -8,15 +8,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

os.environ["TORCHINDUCTOR_FREEZING"] = "1"

import re
import subprocess
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import openvino as ov
import openvino.torch # noqa
import torch
from torch._export import capture_pre_autograd_graph
from torch.fx.passes.graph_drawer import FxGraphDrawer
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
@@ -32,6 +41,36 @@
ROOT = Path(__file__).parent.resolve()


def measure_time(model, example_inputs, num_iters=500):
    with torch.no_grad():
        model(*example_inputs)  # warm-up iteration, excluded from the average
        total_time = 0
        for _ in range(num_iters):
            start_time = time.time()
            model(*example_inputs)
            total_time += time.time() - start_time
        average_time = (total_time / num_iters) * 1000  # milliseconds
    return average_time
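
A minimal usage sketch of this helper (not part of the commit; it assumes a module already compiled with torch.compile and reuses the 640x640 input shape from the export below):

    compiled = torch.compile(model.model, backend="openvino")
    latency_ms = measure_time(compiled, (torch.ones(1, 3, 640, 640),), num_iters=100)
    print(f"Average latency: {latency_ms:.2f} ms")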


def validate_fx(
    model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
    validator.seen = 0
    validator.jdict = []
    validator.stats = []
    validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
    for batch_i, batch in enumerate(data_loader):
        if num_samples is not None and batch_i == num_samples:
            break
        batch = validator.preprocess(batch)
        preds = model(batch["img"])
        preds = validator.postprocess(preds)
        validator.update_metrics(preds, batch)
    stats = validator.get_stats()
    return stats, validator.seen, validator.nt_per_class.sum()


def validate(
    model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
@@ -139,6 +178,66 @@ def transform_fn(data_item: Dict):
    return quantized_model


NNCF_QUANTIZATION = True


def quantize_impl(exported_model, val_loader, validator):
    def transform_fn(x):
        batch = validator.preprocess(x)
        return batch["img"]

    calibration_dataset = nncf.Dataset(val_loader, transform_fn)
    dir_name = str(Path(__file__).parent)
    if NNCF_QUANTIZATION:
        converted_model = nncf.quantize(
            exported_model,
            calibration_dataset,
            # Keep the post-processing subgraph (presumably the detection head)
            # in floating point for accuracy.
            ignored_scope=nncf.IgnoredScope(
                types=["mul", "sub", "sigmoid"],
                subgraphs=[
                    nncf.Subgraph(
                        inputs=["cat_13", "cat_14", "cat_15"],
                        outputs=["output"],
                    )
                ],
            ),
        )
        g = FxGraphDrawer(converted_model, "yolo_nncf_fx_int8")
        g.get_dot_graph().write_svg(dir_name + "/yolo_nncf_fx_int8.svg")

        quantized_model = torch.compile(converted_model, backend="openvino")
        return quantized_model
    else:
        from torch.ao.quantization.quantize_pt2e import convert_pt2e
        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
        from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
        from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

        quantizer = X86InductorQuantizer()
        quantizer.set_global(get_default_x86_inductor_quantization_config())

        prepared_model = prepare_pt2e(exported_model, quantizer)

        # Calibrate on up to 300 batches.
        for idx, batch in tqdm(enumerate(calibration_dataset.get_inference_data())):
            if idx >= 300:
                break
            prepared_model(batch)

        converted_model = convert_pt2e(prepared_model)

        g = FxGraphDrawer(prepared_model, "yolo_torch_fx_int8")
        g.get_dot_graph().write_svg(dir_name + "/yolo_torch_fx_int8.svg")

        import torch._inductor.config as config

        config.cpp_wrapper = True  # use Inductor's C++ wrapper to cut Python overhead

        quantized_model = torch.compile(converted_model)
        return quantized_model
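
The node names passed to nncf.Subgraph above ("cat_13", "cat_14", "cat_15", "output") are names from the captured FX graph itself, presumably bounding the YOLOv8 detection-head post-processing that is kept in floating point. A hedged sketch of how such names can be looked up on the exported model:

    for node in exported_model.graph.nodes:
        if "cat" in node.name or node.op == "output":
            print(node.name, node.op, node.target)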


TORCH_FX = True


def main():
    MODEL_NAME = "yolov8n"

@@ -150,13 +249,39 @@ def main():
    validator, data_loader = prepare_validation(model, args)

    # Convert to OpenVINO model
    if TORCH_FX:
        batch = next(iter(data_loader))
        batch = validator.preprocess(batch)

        with torch.no_grad():
            # fp_stats, total_images, total_objects = validate(model.model, tqdm(data_loader), validator)
            # print("Floating-point model validation results:")
            # print_statistics(fp_stats, total_images, total_objects)
            model.model.eval()
            model.model(batch["img"])  # warm-up run before graph capture
            exported_model = capture_pre_autograd_graph(model.model, args=(batch["img"],))
            quantized_model = quantize_impl(deepcopy(exported_model), data_loader, validator)

        fp32_compiled_model = torch.compile(exported_model, backend="openvino")
        fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator)
        # fp32_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator)
        print("FP32 model validation results:")
        print_statistics(fp32_stats, total_images, total_objects)

        int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
        print("INT8 model validation results:")
        print_statistics(int8_stats, total_images, total_objects)

        print("Start fp32 model benchmarking...")
        fp32_latency = measure_time(fp32_compiled_model, (batch["img"],))
        print(f"fp32 latency: {fp32_latency}")

        print("Start int8 model benchmarking...")
        int8_latency = measure_time(quantized_model, (batch["img"],))
        print(f"int8 latency: {int8_latency}")
        print(f"Speed up: {fp32_latency / int8_latency}")
        return

    example_inputs = torch.ones((1, 3, 640, 640))
    # model.model = torch.compile(model.model)
    # fx_model = model.export(format="torchscript")
    with torch.no_grad():
        model.model.eval()
        capture_pre_autograd_graph(model.model, (example_inputs,))
    ov_model, ov_model_path = prepare_openvino_model(model, MODEL_NAME)

    # Quantize model in OpenVINO representation
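
For reference, the TORCH_FX branch above composes three steps: capture an ahead-of-autograd FX graph, quantize it, and compile it. A condensed, hedged sketch (the `import openvino.torch` at the top of the file is what registers the "openvino" backend for torch.compile; `image` stands in for a preprocessed batch tensor):

    fx_model = capture_pre_autograd_graph(model.model, args=(image,))
    int8_model = quantize_impl(deepcopy(fx_model), data_loader, validator)  # returns a compiled model
    fp32_model = torch.compile(fx_model, backend="openvino")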
7 changes: 7 additions & 0 deletions nncf/common/hardware/configs/cpu.json
@@ -64,6 +64,13 @@
"weights": ["q8_w_sym", "q8_w_asym"]
}
},
{
"type": "Add",
"quantization": {
"activations": "q8_a",
"weights": ["q8_w_sym", "q8_w_asym"]
}
},
{
"type": "Multiply",
"quantization": {
16 changes: 11 additions & 5 deletions nncf/experimental/torch_fx/nncf_graph_builder.py
@@ -58,14 +58,15 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
        node_metatype = om.PTConstNoopMetatype
    elif node.op in ("call_function",):
        if hasattr(node.target, "overloadpacket"):
            node_type = str(node.target.overloadpacket).split(".")[1]
        elif node.target.__name__ == "getitem":
            node_type = "__getitem__"
        else:
            # TODO: get correct node types from these nodes as well
            node_type = str(node.target)
        node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
    # TODO: add layer attrs and support subtypes
    # if node_metatype.get_subtypes():
    #     subtype = node_metatype.determine_subtype(
@@ -208,10 +209,10 @@ def get_module_params_or_buffers():
        for source_node in model.graph.nodes:

            source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
-            for dist_node in source_node.users:
+            for idx, dist_node in enumerate(source_node.users):
                dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
                input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
-                    model, source_node, source_nncf_node, dist_node
+                    model, source_node, source_nncf_node, dist_node, idx
                )

                nncf_graph.add_edge_between_nncf_nodes(
@@ -226,14 +227,19 @@
        return nncf_graph

    @staticmethod
-    def get_edge_params(model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node):
+    # TODO: support cat
+    def get_edge_params(
+        model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int
+    ):
        output_port_id = 0
        if source_node.op in ("get_attr",):
            tensor_shape = tuple(getattr(model, source_node.target).shape)
        elif "val" in source_node.meta:
            if source_nncf_node.metatype is om.PTBatchNormMetatype:
                tensor = source_node.meta["val"][0]
            elif source_nncf_node.metatype is om.PTSplitMetatype:
                tensor = source_node.meta["val"][output_idx]
                # Assume every split output corresponds to a unique output_port_id
                output_port_id = output_idx
            else:
                tensor = source_node.meta["val"]
            tensor_shape = tuple(tensor.shape)
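
The split handling above relies on exported FX graphs storing one FakeTensor per split output in node.meta["val"]; consumer getitem nodes then select by index, which is why output_idx doubles as the output port id. A hedged, standalone illustration (torch.export is assumed to lower chunk to an aten split op; exact op names can vary by PyTorch version):

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            a, b = torch.chunk(x, 2, dim=1)
            return a + 1, b * 2

    ep = export(M(), (torch.ones(1, 4, 8, 8),))
    for node in ep.graph.nodes:
        if node.op == "call_function" and "split" in str(node.target):
            print([tuple(t.shape) for t in node.meta["val"]])  # one entry per output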
9 changes: 7 additions & 2 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -43,6 +43,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
@@ -118,6 +119,7 @@ def hw_config(self) -> HWConfig:

    @property
    def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
+        return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT  # temporarily shadows the FX mapping below
        return DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT

@staticmethod
Expand Down Expand Up @@ -320,8 +322,11 @@ def create_unified_scales_quantizers_insertion_commands(
        )

        # transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
-        transformation = qdq_insertion_tranformation_builder(quantizer, target_points)
-        return [FXApplyTransformationCommand(transformation)]
+        transformations = []
+        for tp in target_points:
+            transformation = qdq_insertion_tranformation_builder(quantizer, [tp])
+            transformations.append(FXApplyTransformationCommand(transformation))
+        return transformations

    @staticmethod
    def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
8 changes: 6 additions & 2 deletions nncf/torch/graph/operator_metatypes.py
@@ -528,7 +528,7 @@ class PTGELUMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register()
class PTSILUMetatype(PTOperatorMetatype):
    name = "SiluOp"
-    module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"]}
+    module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"], NamespaceTarget.ATEN: ["silu_"]}


@PT_OPERATOR_METATYPES.register()
@@ -871,6 +871,7 @@ class PTSplitMetatype(PTOperatorMetatype):
        NamespaceTarget.TORCH_NN_FUNCTIONAL: [],
        NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"],
        NamespaceTarget.TORCH: ["split", "chunk", "unbind"],
+        NamespaceTarget.ATEN: ["split_with_sizes"],
    }
    hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK]

@@ -1036,7 +1037,10 @@ class PTSqrtMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register()
class PTInterpolateMetatype(PTOperatorMetatype):
    name = "InterpolateOp"
-    module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]}
+    module_to_function_names = {
+        NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"],
+        NamespaceTarget.ATEN: ["upsample_nearest2d", "upsample_nearest_exact2d"],
+    }
    hw_config_names = [HWConfigOpName.INTERPOLATE]
    num_expected_input_edges = 1
