Skip to content

Commit

Permalink
fix missing provider
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Oct 3, 2024
1 parent 9970498 commit d111679
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def test_issue_1129_lr(self):
model, x.values, options={"zipmap": False}
)
# Take predictions and probabilities with ONNX
sess = InferenceSession(onnx_model.SerializeToString())
sess = InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_prediction = sess.run(None, {"X": x_test.to_numpy()})
assert_almost_equal(sklearn_probs, onnx_prediction[1], decimal=decimal)
assert_almost_equal(sklearn_preds, onnx_prediction[0])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sklearn_pipeline_concat_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ def test_issue_712_svc_binary_empty(self):
target_opset=TARGET_OPSET,
options={CountVectorizer: {"keep_empty_string": True}},
)
with open("debug.onnx", "wb") as f:
f.write(onx.SerializeToString())
# with open("debug.onnx", "wb") as f:
# f.write(onx.SerializeToString())

Check notice

Code scanning / CodeQL

Commented-out code Note test

This comment appears to contain commented-out code.
sess = InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
Expand Down

0 comments on commit d111679

Please sign in to comment.