Skip to content

Commit

Permalink
Fix #400, support multi:softmax objective (#442)
Browse files Browse the repository at this point in the history
* investigate xgboost issues

Signed-off-by: xavier dupré <[email protected]>

* fix softmax score

Signed-off-by: xavier dupré <[email protected]>

* lint

Signed-off-by: xavier dupré <[email protected]>

* test nightly build

Signed-off-by: xavier dupré <[email protected]>

* restore ci

Signed-off-by: xavier dupré <[email protected]>

* restore one file

Signed-off-by: xavier dupré <[email protected]>

Co-authored-by: xavier dupré <[email protected]>
  • Loading branch information
xadupre and sdpython authored Jan 21, 2021
1 parent 5805a9e commit ccddab5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ jobs:
ONNXRT_PATH: onnxruntime==1.6.0
COREML_PATH: git+https://github.com/apple/[email protected]
xgboost.version: '>=1.0'
Python37-180-RT160-xgb11:
python.version: '3.7'
ONNX_PATH: onnx==1.8.0
ONNXRT_PATH: onnxruntime==1.6.0
COREML_PATH: git+https://github.com/apple/[email protected]
xgboost.version: '<1.2'
maxParallel: 3

steps:
Expand Down
10 changes: 9 additions & 1 deletion onnxmltools/convert/xgboost/_parse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
import re
import numpy as np
from xgboost import XGBRegressor, XGBClassifier
from onnxconverter_common.data_types import FloatTensorType
Expand Down Expand Up @@ -53,8 +54,15 @@ def _get_attributes(booster):
# classification
kwargs['num_target'] = 0
if trees > ntrees > 0:
state = booster.__getstate__()
bstate = bytes(state['handle'])
reg = re.compile(b'(multi:[a-z]{1,15})')
objs = list(set(reg.findall(bstate)))
if len(objs) != 1:
raise RuntimeError(
"Unable to guess objective in {}.".format(objs))
kwargs['num_class'] = trees // ntrees
kwargs["objective"] = "multi:softprob"
kwargs["objective"] = objs[0].decode('ascii')
else:
kwargs['num_class'] = 1
kwargs["objective"] = "binary:logistic"
Expand Down
6 changes: 4 additions & 2 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def convert(scope, operator, container):
else:
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
attr_pairs['post_transform'] = "SOFTMAX"
# attr_pairs['base_values'] = [base_score for n in range(ncl)]
attr_pairs['base_values'] = [base_score for n in range(ncl)]
attr_pairs['class_ids'] = [v % ncl for v in attr_pairs['class_treeids']]

classes = xgb_node.classes_
Expand All @@ -264,8 +264,10 @@ def convert(scope, operator, container):
op_domain='ai.onnx.ml',
name=scope.get_unique_operator_name('TreeEnsembleClassifier'),
**attr_pairs)
elif objective == "multi:softprob":
elif objective in ("multi:softprob", "multi:softmax"):
ncl = len(js_trees) // params['n_estimators']
if objective == 'multi:softmax':
attr_pairs['post_transform'] = 'NONE'
container.add_node('TreeEnsembleClassifier', operator.input_full_names,
operator.output_full_names,
op_domain='ai.onnx.ml',
Expand Down
5 changes: 4 additions & 1 deletion onnxmltools/utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,12 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
if model_dict['objective'].startswith('binary'):
score = model.predict(datax)
prediction = [score > 0.5, numpy.vstack([1-score, score]).T]
elif model_dict['objective'].startswith('multi'):
elif model_dict['objective'].startswith('multi:softprob'):
score = model.predict(datax)
prediction = [score.argmax(axis=1), score]
elif model_dict['objective'].startswith('multi:softmax'):
score = model.predict(datax, output_margin=True)
prediction = [score.argmax(axis=1), score]
else:
prediction = [model.predict(datax)]
elif hasattr(model, "predict_proba"):
Expand Down
23 changes: 20 additions & 3 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Tests scilit-learn's tree-based methods' converters.
"""
import os
import sys
import unittest
import numpy as np
import pandas
Expand Down Expand Up @@ -192,7 +191,7 @@ def test_xgboost_booster_classifier_bin(self):
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
basename="XGBBoosterMCl")

def test_xgboost_booster_classifier_multiclass(self):
def test_xgboost_booster_classifier_multiclass_softprob(self):
x, y = make_classification(n_classes=3, n_features=5,
n_samples=100,
random_state=42, n_informative=3)
Expand All @@ -208,7 +207,25 @@ def test_xgboost_booster_classifier_multiclass(self):
dump_data_and_model(x_test.astype(np.float32),
model, model_onnx,
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
basename="XGBBoosterMCl")
basename="XGBBoosterMClSoftProb")

def test_xgboost_booster_classifier_multiclass_softmax(self):
x, y = make_classification(n_classes=3, n_features=5,
n_samples=100,
random_state=42, n_informative=3)
x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
random_state=42)

data = DMatrix(x_train, label=y_train)
model = train({'objective': 'multi:softmax',
'n_estimators': 3, 'min_child_samples': 1,
'num_class': 3}, data)
model_onnx = convert_xgboost(model, 'tree-based classifier',
[('input', FloatTensorType([None, x.shape[1]]))])
dump_data_and_model(x_test.astype(np.float32),
model, model_onnx,
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
basename="XGBBoosterMClSoftMax")

def test_xgboost_booster_classifier_reg(self):
x, y = make_classification(n_classes=2, n_features=5,
Expand Down

0 comments on commit ccddab5

Please sign in to comment.