Skip to content

Commit

Permalink
Fix ONNX debug script (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
xysmlx authored Sep 16, 2022
1 parent 6537a59 commit bd4f6fe
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions models/pytorch2onnx/ort_run_frozen_each_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,10 @@ def check_shape(shape):
providers = args.provider.split(",")
if "CPUExecutionProvider" not in providers:
providers.append("CPUExecutionProvider")
if 'CUDAExecutionProvider' in ort.get_available_providers() and 'CUDAExecutionProvider' not in providers:
providers = ['CUDAExecutionProvider'] + providers

ort_session = ort.InferenceSession(args.file, sess_options, providers=providers)

if args.provider != '':
ort_session.set_providers([args.provider])
ort_session = ort.InferenceSession(model.SerializeToString(), sess_options, providers=providers)

print("Execution Providers:", ort_session.get_providers())

Expand Down

0 comments on commit bd4f6fe

Please sign in to comment.