YOLO v8: check_export_not_strict
daniil-lyakhov committed Jun 24, 2024
1 parent 8831330 commit 0ea85b7
Showing 4 changed files with 27 additions and 6 deletions.
4 changes: 2 additions & 2 deletions aa_torch_fx.py
@@ -403,12 +403,12 @@ def process_model(model_name: str):
     ##############################################################
     # Process PT Quantize
     ##############################################################
-    # fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input)
+    fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input)
 
     ##############################################################
     # Process NNCF FX Quantize
     ##############################################################
-    nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input)
+    # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input)
 
     ##############################################################
     # Process NNCF Quantize by PT
21 changes: 20 additions & 1 deletion examples/post_training_quantization/openvino/yolov8/main.py
@@ -371,5 +371,24 @@ def main():
     return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf
 
 
+def check_export_not_strict():
+    model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
+
+    # Prepare validation dataset and helper
+    validator, data_loader = prepare_validation_new(model, "coco128.yaml")
+
+    batch = next(iter(data_loader))
+    batch = validator.preprocess(batch)
+
+    model.model(batch["img"])
+    ex_model = torch.export.export(model.model, args=(batch["img"],), strict=False)
+    ex_model = capture_pre_autograd_graph(ex_model.module(), args=(batch["img"],))
+
+    fp_stats, total_images, total_objects = validate_fx(ex_model, tqdm(data_loader), validator)
+    print("Floating-point ex strict=False")
+    print_statistics(fp_stats, total_images, total_objects)
+
+
 if __name__ == "__main__":
-    main()
+    check_export_not_strict()
+    # main()
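For context, the new check_export_not_strict path exports the model with strict=False and then re-captures it as a pre-autograd FX graph before validation. A minimal, self-contained sketch of that pattern on a generic module (TinyModel and its input are illustrative placeholders, not part of the commit):

import torch
from torch._export import capture_pre_autograd_graph


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


model = TinyModel().eval()
example_input = torch.randn(1, 3, 224, 224)

# strict=False relaxes the exporter's checks, so models with constructs the
# strict tracer rejects (e.g. data-dependent control flow) can still export.
exported = torch.export.export(model, args=(example_input,), strict=False)

# Re-capture the unlifted module as a pre-autograd FX graph, mirroring the
# two-step flow in check_export_not_strict.
fx_model = capture_pre_autograd_graph(exported.module(), args=(example_input,))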
3 changes: 2 additions & 1 deletion nncf/experimental/torch_fx/nncf_graph_builder.py
@@ -236,6 +236,7 @@ def view_to_reshape(model: torch.fx.GraphModule):
             continue
         with model.graph.inserting_after(n):
             reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {})
+            reshape.meta = n.meta
 
         for user in list(n.users):
             user.replace_input_with(n, reshape)
@@ -295,7 +296,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
     GraphConverter.separate_conv_and_bias(model)
     GraphConverter.unfold_scaled_dot_product_attention(model)
     GraphConverter.view_to_reshape(model)
-    breakpoint()
+    # breakpoint()
 
     nncf_graph = PTNNCFGraph()
 
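The reshape.meta = n.meta line added above matters because downstream FX passes read shape and dtype annotations from node.meta, and a freshly created node starts with an empty meta dict. A self-contained sketch of the same view-to-reshape rewrite with the metadata carried over (the dead-code cleanup and recompile at the end are assumptions, not shown in the hunk):

import torch


def view_to_reshape(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for n in list(model.graph.nodes):
        if n.op != "call_function" or n.target != torch.ops.aten.view.default:
            continue
        with model.graph.inserting_after(n):
            reshape = model.graph.create_node(
                "call_function", torch.ops.aten.reshape.default, tuple(n.args), {}
            )
            # Carry over shapes, dtypes, and source traces so later passes
            # that inspect node.meta still see valid annotations.
            reshape.meta = n.meta
        for user in list(n.users):
            user.replace_input_with(n, reshape)
    # Remove the now-unused view nodes and regenerate forward().
    model.graph.eliminate_dead_code()
    model.recompile()
    return model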
5 changes: 3 additions & 2 deletions nncf/experimental/torch_fx/transformations.py
@@ -105,8 +105,9 @@ def insert_one_qdq(
     dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
 
     # Quantized functions accept only uint8 as an input
-    if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8:
-        raise RuntimeError("Wrong parameters: activations should always be uint8")
+    # if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8:
+    #     breakpoint()
+    #     raise RuntimeError("Wrong parameters: activations should always be uint8")
 
     # TODO: map FakeQuantizeParameters to qparams for quantize/dequantize
     # 2. replace activation_post_process node with quantize and dequantize
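The guard commented out above enforced that activation quantizers use uint8. For illustration, a minimal sketch of the per-tensor quantize/dequantize pair these ops implement; the scale and zero-point values are made up, and the private import is one known way to make the quantized_decomposed ops available:

import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops

x = torch.randn(4, 8)
scale, zero_point = 0.02, 128

# Activations quantized into the uint8 range [0, 255]...
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, 0, 255, torch.uint8
)
# ...and mapped back to float for the next floating-point operation.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, 0, 255, torch.uint8
)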
