diff --git a/models/pytorch2onnx/ort_run_frozen_each_layer.py b/models/pytorch2onnx/ort_run_frozen_each_layer.py index 8bcd8ada7..efdf2838e 100644 --- a/models/pytorch2onnx/ort_run_frozen_each_layer.py +++ b/models/pytorch2onnx/ort_run_frozen_each_layer.py @@ -98,11 +98,10 @@ def check_shape(shape): providers = args.provider.split(",") if "CPUExecutionProvider" not in providers: providers.append("CPUExecutionProvider") +if 'CUDAExecutionProvider' in ort.get_available_providers() and 'CUDAExecutionProvider' not in providers: + providers = ['CUDAExecutionProvider'] + providers -ort_session = ort.InferenceSession(args.file, sess_options, providers=providers) - -if args.provider != '': - ort_session.set_providers([args.provider]) +ort_session = ort.InferenceSession(model.SerializeToString(), sess_options, providers=providers) print("Execution Providers:", ort_session.get_providers())