Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix converter for DecisionTreeClassifier if n_classses == 1 #1008

Merged
merged 8 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/examples/plot_cast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"""
import onnxruntime
import onnx
import numpy
import os
import math
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion docs/examples/plot_tfidfvectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import numpy
import onnxruntime as rt
from skl2onnx.common.data_types import StringTensorType
from skl2onnx import convert_sklearn
Expand Down
56 changes: 52 additions & 4 deletions skl2onnx/operator_converters/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numbers
import numpy as np
from onnx.numpy_helper import from_array
from ..common._apply_operation import (
apply_cast,
apply_concat,
Expand Down Expand Up @@ -124,7 +125,7 @@ def predict(
[indices_name, dummy_proba_name],
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)
else:
zero_name = scope.get_unique_variable_name("zero")
Expand Down Expand Up @@ -243,7 +244,7 @@ def _append_decision_output(
dpath,
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

if n_out is None:
Expand Down Expand Up @@ -306,6 +307,53 @@ def convert_sklearn_decision_tree_classifier(
dtype = np.float32
op = operator.raw_operator
options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
if np.asarray(op.classes_).size == 1:
# The model was trained with one label.
# There is no need to build a tree.
if op.n_outputs_ != 1:
raise RuntimeError(
f"One training class and multiple outputs is not "
f"supported yet for class {op.__class__.__name__!r}."
)
if options["decision_path"] or options["decision_leaf"]:
raise RuntimeError(
f"One training class, option 'decision_path' "
f"or 'decision_leaf' are not supported for "
f"class {op.__class__.__name__!r}."
)

zero = scope.get_unique_variable_name("zero")
one = scope.get_unique_variable_name("one")
new_shape = scope.get_unique_variable_name("new_shape")
container.add_initializer(zero, onnx_proto.TensorProto.INT64, [1], [0])
container.add_initializer(one, onnx_proto.TensorProto.INT64, [1], [1])
container.add_initializer(new_shape, onnx_proto.TensorProto.INT64, [2], [-1, 1])
shape = scope.get_unique_variable_name("shape")
container.add_node("Shape", [operator.inputs[0].full_name], [shape])
shape_sliced = scope.get_unique_variable_name("shape_sliced")
container.add_node("Slice", [shape, zero, one, zero], [shape_sliced])

# labels
container.add_node(
"ConstantOfShape",
[shape_sliced],
[operator.outputs[0].full_name],
value=from_array(np.array([op.classes_[0]], dtype=np.int64)),
)

# probabilities
probas = scope.get_unique_variable_name("probas")
container.add_node(
"ConstantOfShape",
[shape_sliced],
[probas],
value=from_array(np.array([1], dtype=dtype)),
)
container.add_node(
"Reshape", [probas, new_shape], [operator.outputs[1].full_name]
)
return

if op.n_outputs_ == 1:
attrs = get_default_tree_classifier_attribute_pairs()
attrs["name"] = scope.get_unique_operator_name(op_type)
Expand Down Expand Up @@ -355,7 +403,7 @@ def convert_sklearn_decision_tree_classifier(
[operator.outputs[0].full_name, operator.outputs[1].full_name],
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

n_out = 2
Expand Down Expand Up @@ -510,7 +558,7 @@ def convert_sklearn_decision_tree_regressor(
operator.outputs[0].full_name,
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
Expand Down
45 changes: 45 additions & 0 deletions tests/test_sklearn_classifiers_extreme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0

import unittest
import numpy as np

try:
from onnx.reference import ReferenceEvaluator
except ImportError:
ReferenceEvaluator = None
from sklearn.tree import DecisionTreeClassifier
from onnxruntime import InferenceSession
from skl2onnx import to_onnx
from test_utils import TARGET_OPSET


class TestSklearnClassifiersExtreme(unittest.TestCase):
def test_one_training_class(self):
x = np.eye(4, dtype=np.float32)
y = np.array([5, 5, 5, 5], dtype=np.int64)

cl = DecisionTreeClassifier()
cl = cl.fit(x, y)

expected = [cl.predict(x), cl.predict_proba(x)]
onx = to_onnx(cl, x, target_opset=TARGET_OPSET, options={"zipmap": False})

for cls in [
(lambda onx: ReferenceEvaluator(onx, verbose=0))
if ReferenceEvaluator is not None
else None,
lambda onx: InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
),
]:
if cls is None:
continue
sess = cls(onx)
res = sess.run(None, {"X": x})
self.assertEqual(len(res), len(expected))
for e, g in zip(expected, res):
self.assertEqual(e.tolist(), g.tolist())


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading