From 0ea85b7abe3da565fa008289a1f208b40d2922e5 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 24 Jun 2024 18:06:00 +0200 Subject: [PATCH] YOLO v8: check_export_not_strict --- aa_torch_fx.py | 4 ++-- .../openvino/yolov8/main.py | 21 ++++++++++++++++++- .../torch_fx/nncf_graph_builder.py | 3 ++- nncf/experimental/torch_fx/transformations.py | 5 +++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/aa_torch_fx.py b/aa_torch_fx.py index 2dabac92a52..269d9e228ab 100644 --- a/aa_torch_fx.py +++ b/aa_torch_fx.py @@ -403,12 +403,12 @@ def process_model(model_name: str): ############################################################## # Process PT Quantize ############################################################## - # fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) ############################################################## # Process NNCF FX Quantize ############################################################## - nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) ############################################################## # Process NNCF Quantize by PT diff --git a/examples/post_training_quantization/openvino/yolov8/main.py b/examples/post_training_quantization/openvino/yolov8/main.py index 7e2e7b93c33..ae4dc576c5c 100644 --- a/examples/post_training_quantization/openvino/yolov8/main.py +++ b/examples/post_training_quantization/openvino/yolov8/main.py @@ -371,5 +371,24 @@ def main(): return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf +def check_export_not_strict(): + model = YOLO(f"{ROOT}/{MODEL_NAME}.pt") + + # Prepare validation dataset and helper + validator, data_loader = prepare_validation_new(model, "coco128.yaml") + + batch = next(iter(data_loader)) + batch = validator.preprocess(batch) + + model.model(batch["img"]) + ex_model = torch.export.export(model.model, args=(batch["img"],), strict=False) + ex_model = capture_pre_autograd_graph(ex_model.module(), args=(batch["img"],)) + + fp_stats, total_images, total_objects = validate_fx(ex_model, tqdm(data_loader), validator) + print("Floating-point ex strict=False") + print_statistics(fp_stats, total_images, total_objects) + + if __name__ == "__main__": - main() + check_export_not_strict() + # main() diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py index 9c4ac1e581c..05a381d8840 100644 --- a/nncf/experimental/torch_fx/nncf_graph_builder.py +++ b/nncf/experimental/torch_fx/nncf_graph_builder.py @@ -236,6 +236,7 @@ def view_to_reshape(model: torch.fx.GraphModule): continue with model.graph.inserting_after(n): reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {}) + reshape.meta = n.meta for user in list(n.users): user.replace_input_with(n, reshape) @@ -295,7 +296,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: GraphConverter.separate_conv_and_bias(model) GraphConverter.unfold_scaled_dot_product_attention(model) GraphConverter.view_to_reshape(model) - breakpoint() + # breakpoint() nncf_graph = PTNNCFGraph() diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py index be1872644e9..2af2ddf2469 100644 --- a/nncf/experimental/torch_fx/transformations.py +++ b/nncf/experimental/torch_fx/transformations.py @@ -105,8 +105,9 @@ def insert_one_qdq( dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default # Quantized functions accepts only uint8 as an input - if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8: - raise RuntimeError("Wrong parameters: activations should always be uint8") + # if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8: + # breakpoint() + # raise RuntimeError("Wrong parameters: activations should always be uint8") # TODO: map FakeQuantizePramaeters to qparams for quantize/dequantize # 2. replace activation_post_process node with quantize and dequantize