From bd4f6feed217a43c9ee9be16f02fa8529953579a Mon Sep 17 00:00:00 2001
From: Lingxiao Ma
Date: Fri, 16 Sep 2022 13:59:01 +0800
Subject: [PATCH] Fix ONNX debug script (#460)

---
 models/pytorch2onnx/ort_run_frozen_each_layer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/models/pytorch2onnx/ort_run_frozen_each_layer.py b/models/pytorch2onnx/ort_run_frozen_each_layer.py
index 8bcd8ada7..efdf2838e 100644
--- a/models/pytorch2onnx/ort_run_frozen_each_layer.py
+++ b/models/pytorch2onnx/ort_run_frozen_each_layer.py
@@ -98,11 +98,10 @@ def check_shape(shape):
 providers = args.provider.split(",")
 if "CPUExecutionProvider" not in providers:
     providers.append("CPUExecutionProvider")
+if 'CUDAExecutionProvider' in ort.get_available_providers() and 'CUDAExecutionProvider' not in providers:
+    providers = ['CUDAExecutionProvider'] + providers
 
-ort_session = ort.InferenceSession(args.file, sess_options, providers=providers)
-
-if args.provider != '':
-    ort_session.set_providers([args.provider])
+ort_session = ort.InferenceSession(model.SerializeToString(), sess_options, providers=providers)
 
 print("Execution Providers:", ort_session.get_providers())
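
For reference, below is a minimal standalone sketch of the provider-selection pattern this patch applies: keep a CPU fallback, prefer CUDA when the installed onnxruntime build supports it, and create the session from the serialized model bytes rather than the file path. It assumes the onnx and onnxruntime packages are installed; the model path and the requested-provider string are placeholders, not taken from the repository.

    # Minimal sketch of the provider selection applied in the patch above.
    # Assumes `onnx` and `onnxruntime` are installed; the model path below
    # is a placeholder for illustration, not part of the patch.
    import onnx
    import onnxruntime as ort

    model_path = "model.onnx"  # hypothetical path
    requested = ""             # corresponds to the script's --provider argument

    # Start from any explicitly requested providers; always keep a CPU fallback.
    providers = [p for p in requested.split(",") if p]
    if "CPUExecutionProvider" not in providers:
        providers.append("CPUExecutionProvider")

    # Prefer CUDA when this onnxruntime build exposes it.
    if ("CUDAExecutionProvider" in ort.get_available_providers()
            and "CUDAExecutionProvider" not in providers):
        providers = ["CUDAExecutionProvider"] + providers

    # Load the model once and build the session from its serialized bytes,
    # as the patched script does, instead of passing the file path directly.
    model = onnx.load(model_path)
    sess_options = ort.SessionOptions()
    ort_session = ort.InferenceSession(model.SerializeToString(),
                                       sess_options, providers=providers)
    print("Execution Providers:", ort_session.get_providers())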