Skip to content

Commit

Permalink
Resnet18 example acc and performance alligned nncf/x86 inductor
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 3, 2024
1 parent 5a0d546 commit 681a4c5
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 14 deletions.
58 changes: 54 additions & 4 deletions examples/quantization_aware_training/torch/resnet18/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import re
import subprocess
import time
import warnings
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple

import openvino as ov
import openvino.torch # noqa
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
Expand All @@ -37,6 +39,7 @@
from torch.ao.quantization.quantize_pt2e import convert_pt2e
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch.jit import TracerWarning

import nncf
Expand All @@ -63,6 +66,18 @@
DATASET_PATH = "~/.cache/nncf/datasets"


def measure_time(model, example_inputs, num_iters):
with torch.no_grad():
model(*example_inputs)
total_time = 0
for i in range(0, num_iters):
start_time = time.time()
model(*example_inputs)
total_time += time.time() - start_time
average_time = (total_time / num_iters) * 1000
return average_time


def download_dataset() -> Path:
downloader = FastDownload(base=DATASET_PATH, archive="downloaded", data="extracted")
return downloader.get(DATASET_URL)
Expand Down Expand Up @@ -264,7 +279,7 @@ def transform_fn(data_item):

with torch.no_grad():
example_inputs = (torch.ones((1, 3, IMAGE_SIZE, IMAGE_SIZE)),)
exported_model = capture_pre_autograd_graph(model, example_inputs)
exported_model = capture_pre_autograd_graph(model.eval(), example_inputs)

NNCF_TORCH_FX = False

Expand All @@ -277,15 +292,50 @@ def transform_fn(data_item):

from tqdm import tqdm

for data in tqdm(islice(quantization_dataset.get_inference_data(), 3)):
for data in tqdm(islice(quantization_dataset.get_inference_data(), 300)):
prepared_model(data)
quantized_model = convert_pt2e(prepared_model)

g = FxGraphDrawer(quantized_model, "acc_resnet18_int8_native")
g.get_dot_graph().write_svg("acc_resnet18_int8_native.svg")
else:
quantized_model = nncf.quantize(exported_model, quantization_dataset)
g = FxGraphDrawer(quantized_model, "acc_resnet18_int8_nncf")
g.get_dot_graph().write_svg("acc_resnet18_int8_nncf.svg")

quantized_model = torch.compile(quantized_model)
acc1_int8_init = validate(val_loader, quantized_model, device)
# quantized_model = torch.compile(quantized_model)
# acc1_int8_init = validate(val_loader, quantized_model, device)
acc1_int8_init = validate(val_loader, torch.compile(quantized_model), device)
print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}")

num_iters = 100

print("original model execution time: ", measure_time(model, example_inputs, num_iters))
native_optimized_model_fp32 = torch.compile(exported_model)
print(
"Torch Inductor FP32 model execution time: ",
measure_time(native_optimized_model_fp32, example_inputs, num_iters),
)
native_optimized_model_int8 = torch.compile(quantized_model)
print(
"Torch Inductor INT8 model execution time: ",
measure_time(native_optimized_model_int8, example_inputs, num_iters),
)

ov_optimized_model_fp32 = torch.compile(exported_model, backend="openvino")
print(
"Torch.compile OpenVINO FP32 model execution time: ",
measure_time(ov_optimized_model_fp32, example_inputs, num_iters),
)

ov_optimized_model_int8 = torch.compile(
quantized_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"}
)
print(
"Torch.compile OpenVINO INT8 model execution time: ",
measure_time(ov_optimized_model_int8, example_inputs, num_iters),
)

return
###############################################################################
# Step 3: Fine tune quantized model
Expand Down
4 changes: 4 additions & 0 deletions nncf/experimental/torch_fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
# TODO: get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if node_metatype is UnknownMetatype:
breakpoint()
# TODO: add layer attrs and support subtypes
# if node_metatype.get_subtypes():
# subtype = node_metatype.determine_subtype(
Expand Down Expand Up @@ -235,6 +237,8 @@ def get_edge_params(model, source_node: torch.fx.Node, source_nncf_node: NNCFNod
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
else:
if isinstance(source_node.meta["val"], tuple):
breakpoint()
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
operator_metatypes.PTLayerNormMetatype,
operator_metatypes.PTModuleLayerNormMetatype,
# operator_metatypes.PTAddMetatype,
operator_metatypes.PTReshapeMetatype,
operator_metatypes.PTMulMetatype,
operator_metatypes.PTDivMetatype,
operator_metatypes.PTMatMulMetatype,
Expand Down Expand Up @@ -80,7 +81,7 @@
operator_metatypes.PTTransposeMetatype,
operator_metatypes.PTGatherMetatype,
operator_metatypes.PTScatterMetatype,
operator_metatypes.PTReshapeMetatype,
# operator_metatypes.PTReshapeMetatype,
operator_metatypes.PTSqueezeMetatype,
operator_metatypes.PTSplitMetatype,
operator_metatypes.PTExpandMetatype,
Expand Down
6 changes: 4 additions & 2 deletions nncf/experimental/torch_fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ def insert_one_qdq(
}
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default

# Quantized functions accepts only uint8 as an input
if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8:
qparams["_zero_point_"] = 125
qparams["_zero_point_"] = qparams["_zero_point_"] - qparams["_quant_min_"]
quants_len = qparams["_quant_max_"] - qparams["_quant_min_"]
qparams["_quant_min_"] = 0
qparams["_quant_max_"] = 255
qparams["_quant_max_"] = quants_len
qparams["_dtype_"] = torch.uint8
# TODO: map FakeQuantizePramaeters to qparams for quantize/dequantize
# 2. replace activation_post_process node with quantize and dequantize
Expand Down
14 changes: 7 additions & 7 deletions torch_compile_ex_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_exported_model_from_nn_module(module, example_inputs):
return capture_pre_autograd_graph(module, example_inputs)


NNCF_IMPL = False
NNCF_IMPL = True


def get_qsetup(exported_model, example_inputs):
Expand Down Expand Up @@ -163,13 +163,13 @@ def main(model_name, num_iters):

converted_model = quantize(copy.deepcopy(model), example_inputs)

# print("original model execution time: ", measure_time(model, example_inputs, num_iters))
print("original model execution time: ", measure_time(model, example_inputs, num_iters))

# native_optimized_model_fp32 = torch.compile(model)
# print(
# "Torch Inductor FP32 model execution time: ",
# measure_time(native_optimized_model_fp32, example_inputs, num_iters),
# )
native_optimized_model_fp32 = torch.compile(model)
print(
"Torch Inductor FP32 model execution time: ",
measure_time(native_optimized_model_fp32, example_inputs, num_iters),
)

native_optimized_model_int8 = torch.compile(converted_model)
print(
Expand Down

0 comments on commit 681a4c5

Please sign in to comment.