Skip to content

Commit

Permalink
TreeEnsemble base values for the reference implementation (onnx#5665)
Browse files Browse the repository at this point in the history
This PR fixes one line that was failing in onnx#5569. The text below is taken
from that PR's description.

### Description

In the reference implementation of TreeEnsembleRegressor, when a value is
provided for the base_values argument, this value replaces any
prediction. This can't be what's intended and does not match
onnxruntime. Instead, base_values should be added to the prediction
after applying any aggregation.

### Motivation and Context

We are exporting regression trees into ONNX; these trees have a non-zero
base_value that serves as the baseline prediction for the tree. Prediction
works as expected in onnxruntime but not in the reference implementation. I
believe this is an oversight and propose the fix (plus tests) below. I
also think the documentation should be more explicit about what the
base_values argument does.

---------

Signed-off-by: Corwin Joy <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: Corwin Joy <[email protected]>
Co-authored-by: G. Ramalingam <[email protected]>
  • Loading branch information
3 people authored Oct 18, 2023
1 parent 5f908a9 commit 8fd6971
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/Changelog-ml.md
Original file line number Diff line number Diff line change
Expand Up @@ -1085,9 +1085,9 @@ This version of the operator has been available since version 3 of the 'ai.onnx.
<dt><tt>aggregate_function</tt> : string (default is SUM)</dt>
<dd>Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'</dd>
<dt><tt>base_values</tt> : list of floats</dt>
<dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>base_values_as_tensor</tt> : tensor</dt>
<dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>n_targets</tt> : int</dt>
<dd>The total number of targets.</dd>
<dt><tt>nodes_falsenodeids</tt> : list of ints</dt>
Expand Down
4 changes: 2 additions & 2 deletions docs/Operators-ml.md
Original file line number Diff line number Diff line change
Expand Up @@ -1038,9 +1038,9 @@ Other versions of this operator: <a href="Changelog-ml.md#ai.onnx.ml.TreeEnsembl
<dt><tt>aggregate_function</tt> : string (default is SUM)</dt>
<dd>Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'</dd>
<dt><tt>base_values</tt> : list of floats</dt>
<dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>base_values_as_tensor</tt> : tensor</dt>
<dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>n_targets</tt> : int</dt>
<dd>The total number of targets.</dd>
<dt><tt>nodes_falsenodeids</tt> : list of ints</dt>
Expand Down
4 changes: 2 additions & 2 deletions onnx/defs/traditionalml/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -998,12 +998,12 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
std::string("SUM"))
.Attr(
"base_values",
"Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"base_values_as_tensor",
"Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
Expand Down
4 changes: 3 additions & 1 deletion onnx/reference/ops/aionnxml/op_tree_ensemble_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ def _run( # type: ignore
)
if aggregate_function == "AVERAGE":
res /= n_trees

# Convention is to add base_values after aggregate function
if base_values is not None:
res[:, :] = np.array(base_values).reshape((1, -1))
res[:, :] += np.array(base_values).reshape((1, -1))

if post_transform in (None, "NONE"):
return (res,)
Expand Down
33 changes: 23 additions & 10 deletions onnx/test/reference_evaluator_ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np # type: ignore
from numpy.testing import assert_allclose # type: ignore
from parameterized import parameterized

from onnx import ONNX_ML, TensorProto, TypeProto, ValueInfoProto
from onnx.checker import check_model
Expand Down Expand Up @@ -757,7 +758,7 @@ def test_linear_classifier_unary(self):

@staticmethod
def _get_test_tree_ensemble_regressor(
aggregate_function, rule="BRANCH_LEQ", unique_targets=False
aggregate_function, rule="BRANCH_LEQ", unique_targets=False, base_values=None
):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])
Expand Down Expand Up @@ -786,6 +787,7 @@ def _get_test_tree_ensemble_regressor(
domain="ai.onnx.ml",
n_targets=1,
aggregate_function=aggregate_function,
base_values=base_values,
nodes_falsenodeids=[4, 3, 0, 0, 0, 2, 0, 4, 0, 0],
nodes_featureids=[0, 2, 0, 0, 0, 0, 0, 2, 0, 0],
nodes_hitrates=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
Expand Down Expand Up @@ -828,23 +830,34 @@ def _get_test_tree_ensemble_regressor(
check_model(onx)
return onx

@parameterized.expand(
[
(f"{agg}_{base_values}", base_values, agg)
for base_values in (None, [1.0])
for agg in ("SUM", "AVERAGE", "MIN", "MAX")
]
)
@unittest.skipIf(not ONNX_ML, reason="onnx not compiled with ai.onnx.ml")
def test_tree_ensemble_regressor(self):
def test_tree_ensemble_regressor(self, name, base_values, agg):
self.assertTrue(ONNX_ML)
del name # variable only used to print test name
x = np.arange(9).reshape((-1, 3)).astype(np.float32) / 10 - 0.5
expected_agg = {
"SUM": np.array([[0.576923], [0.576923], [0.576923]], dtype=np.float32),
"AVERAGE": np.array([[0.288462], [0.288462], [0.288462]], dtype=np.float32),
"MIN": np.array([[0.076923], [0.076923], [0.076923]], dtype=np.float32),
"MAX": np.array([[0.5], [0.5], [0.5]], dtype=np.float32),
}
for agg in ("SUM", "AVERAGE", "MIN", "MAX"):
expected = expected_agg[agg]
with self.subTest(aggregate_function=agg):
onx = self._get_test_tree_ensemble_regressor(agg)
self._check_ort(onx, {"X": x}, equal=True)
sess = ReferenceEvaluator(onx)
got = sess.run(None, {"X": x})
assert_allclose(expected, got[0], atol=1e-6)

expected = expected_agg[agg]
if base_values is not None:
expected += base_values[0]
with self.subTest(aggregate_function=agg):
onx = self._get_test_tree_ensemble_regressor(agg, base_values=base_values)
self._check_ort(onx, {"X": x}, equal=True)
sess = ReferenceEvaluator(onx)
got = sess.run(None, {"X": x})
assert_allclose(expected, got[0], atol=1e-6)

@unittest.skipIf(not ONNX_ML, reason="onnx not compiled with ai.onnx.ml")
def test_tree_ensemble_regressor_rule(self):
Expand Down

0 comments on commit 8fd6971

Please sign in to comment.