Skip to content

Commit

Permalink
fix merge conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Aug 1, 2023
2 parents e13426b + 8a4a803 commit 293312c
Show file tree
Hide file tree
Showing 360 changed files with 29,815 additions and 21,216 deletions.
11 changes: 3 additions & 8 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,6 @@ jobs:
fi
displayName: 'install onnx'
- script: |
pip install flake8
displayName: 'install flake8'
- script: |
pip install $(onnxrt.version)
displayName: 'install onnxruntime'
Expand Down Expand Up @@ -334,11 +330,10 @@ jobs:
displayName: 'pytest-onnxmltools'
condition: eq(variables['run.example'], '1')
# Check flake8 after the tests to get more feedback.
# It is checked before the tests on the windows build.
- script: |
flake8 skl2onnx tests tests_onnxmltools
displayName: 'flake8'
python -m pip install ruff
ruff skl2onnx tests tests_onnxmltools
displayName: 'ruff'
- script: |
if [ '$(onnx.target_opset)' != '' ]
Expand Down
10 changes: 3 additions & 7 deletions .azure-pipelines/win32-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,9 @@ jobs:
- script: |
call activate skl2onnxEnvironment
pip install flake8
displayName: 'install flake8'
- script: |
call activate skl2onnxEnvironment
flake8 skl2onnx tests tests_onnxmltools
displayName: 'flake8'
python -m pip install ruff
ruff skl2onnx tests tests_onnxmltools
displayName: 'ruff'
- script: |
call activate skl2onnxEnvironment
Expand Down
16 changes: 16 additions & 0 deletions .github/workflows/black-ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Black Format Checker
on: [push, pull_request]
jobs:
black-format-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: psf/black@stable
with:
options: "--diff --check"
src: "."
ruff-format-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

<p align="center"><img width="50%" src="docs/logo_main.png" /></p>

| Linux | Windows |
|-------|---------|
| [![Build Status](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status/sklearn-onnx-linux-conda-ci?branchName=master)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=5?branchName=master) | [![Build Status](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status/sklearn-onnx-win32-conda-ci?branchName=master)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=5?branchName=master)|
[![Build Status Linux](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status%2Fonnx.sklearn-onnx.linux.CI?branchName=refs%2Fpull%2F1009%2Fmerge)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=21&branchName=refs%2Fpull%2F1009%2Fmerge)

[![Build Status Windows](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status%2Fonnx.sklearn-onnx.win.CI?branchName=refs%2Fpull%2F1009%2Fmerge)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=22&branchName=refs%2Fpull%2F1009%2Fmerge)

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

## Introduction
*sklearn-onnx* converts [scikit-learn](https://scikit-learn.org/stable/) models to [ONNX](https://github.com/onnx/onnx).
Once in the ONNX format, you can use tools like [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high performance scoring.
All converters are tested with [onnxruntime](https://onnxruntime.ai/).
Any external converter can be registered to convert scikit-learn pipeline
including models or transformers coming from external libraries.

## Documentation
Full documentation including tutorials is available at [https://onnx.ai/sklearn-onnx/](https://onnx.ai/sklearn-onnx/).
Expand Down
115 changes: 69 additions & 46 deletions benchmarks/bench_plot_onnxruntime_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pandas
from sklearn import config_context
from sklearn.tree import DecisionTreeClassifier

try:
# scikit-learn >= 0.22
from sklearn.utils._testing import ignore_warnings
Expand All @@ -29,14 +30,18 @@
# Implementations to benchmark.
##############################


def fcts_model(X, y, max_depth):
"DecisionTreeClassifier."
rf = DecisionTreeClassifier(max_depth=max_depth)
rf.fit(X, y)

initial_types = [('X', FloatTensorType([None, X.shape[1]]))]
onx = convert_sklearn(rf, initial_types=initial_types,
options={DecisionTreeClassifier: {'zipmap': False}})
initial_types = [("X", FloatTensorType([None, X.shape[1]]))]
onx = convert_sklearn(
rf,
initial_types=initial_types,
options={DecisionTreeClassifier: {"zipmap": False}},
)
f = BytesIO()
f.write(onx.SerializeToString())
content = f.getvalue()
Expand All @@ -51,30 +56,29 @@ def predict_skl_predict_proba(X, model=rf):
return rf.predict_proba(X)

def predict_onnxrt_predict(X, sess=sess):
return sess.run(outputs[:1], {'X': X})[0]
return sess.run(outputs[:1], {"X": X})[0]

def predict_onnxrt_predict_proba(X, sess=sess):
return sess.run(outputs[1:], {'X': X})[0]
return sess.run(outputs[1:], {"X": X})[0]

return {'predict': (predict_skl_predict,
predict_onnxrt_predict),
'predict_proba': (predict_skl_predict_proba,
predict_onnxrt_predict_proba)}
return {
"predict": (predict_skl_predict, predict_onnxrt_predict),
"predict_proba": (predict_skl_predict_proba, predict_onnxrt_predict_proba),
}


##############################
# Benchmarks
##############################


def allow_configuration(**kwargs):
return True


def bench(n_obs, n_features, max_depths, methods,
repeat=10, verbose=False):
def bench(n_obs, n_features, max_depths, methods, repeat=10, verbose=False):
res = []
for nfeat in n_features:

ntrain = 100000
X_train = np.empty((ntrain, nfeat))
X_train[:, :] = rand(ntrain, nfeat)[:, :].astype(np.float32)
Expand All @@ -88,15 +92,12 @@ def bench(n_obs, n_features, max_depths, methods,

for n in n_obs:
for method in methods:

fct1, fct2 = fcts[method]

if not allow_configuration(
n=n, nfeat=nfeat, max_depth=max_depth):
if not allow_configuration(n=n, nfeat=nfeat, max_depth=max_depth):
continue

obs = dict(n_obs=n, nfeat=nfeat,
max_depth=max_depth, method=method)
obs = dict(n_obs=n, nfeat=nfeat, max_depth=max_depth, method=method)

# creates different inputs to avoid caching in any ways
Xs = []
Expand Down Expand Up @@ -143,11 +144,11 @@ def bench(n_obs, n_features, max_depths, methods,
# Plots.
##############################


def plot_results(df, verbose=False):
nrows = max(len(set(df.max_depth)) * len(set(df.n_obs)), 2)
ncols = max(len(set(df.method)), 2)
fig, ax = plt.subplots(nrows, ncols,
figsize=(ncols * 4, nrows * 4))
fig, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))
pos = 0
row = 0
for n_obs in sorted(set(df.n_obs)):
Expand All @@ -156,31 +157,49 @@ def plot_results(df, verbose=False):
for method in sorted(set(df.method)):
a = ax[row, pos]
if row == ax.shape[0] - 1:
a.set_xlabel("N features", fontsize='x-small')
a.set_xlabel("N features", fontsize="x-small")
if pos == 0:
a.set_ylabel(
"Time (s) n_obs={}\nmax_depth={}".format(
n_obs, max_depth),
fontsize='x-small')

color = 'b'
subset = df[(df.method == method) & (df.n_obs == n_obs) &
(df.max_depth == max_depth)]
"Time (s) n_obs={}\nmax_depth={}".format(n_obs, max_depth),
fontsize="x-small",
)

color = "b"
subset = df[
(df.method == method)
& (df.n_obs == n_obs)
& (df.max_depth == max_depth)
]
if subset.shape[0] == 0:
continue
subset = subset.sort_values("nfeat")
if verbose:
print(subset)
label = "skl"
subset.plot(x="nfeat", y="time_skl", label=label, ax=a,
logx=True, logy=True, c=color, style='--')
subset.plot(
x="nfeat",
y="time_skl",
label=label,
ax=a,
logx=True,
logy=True,
c=color,
style="--",
)
label = "ort"
subset.plot(x="nfeat", y="time_ort", label=label, ax=a,
logx=True, logy=True, c=color)

a.legend(loc=0, fontsize='x-small')
subset.plot(
x="nfeat",
y="time_ort",
label=label,
ax=a,
logx=True,
logy=True,
c=color,
)

a.legend(loc=0, fontsize="x-small")
if row == 0:
a.set_title("method={}".format(method), fontsize='x-small')
a.set_title("method={}".format(method), fontsize="x-small")
pos += 1
row += 1

Expand All @@ -190,13 +209,14 @@ def plot_results(df, verbose=False):
@ignore_warnings(category=FutureWarning)
def run_bench(repeat=100, verbose=False):
n_obs = [1, 10, 100, 1000, 10000, 100000]
methods = ['predict', 'predict_proba']
methods = ["predict", "predict_proba"]
n_features = [1, 5, 10, 20, 50, 100, 200]
max_depths = [2, 5, 10, 20]

start = time()
results = bench(n_obs, n_features, max_depths, methods,
repeat=repeat, verbose=verbose)
results = bench(
n_obs, n_features, max_depths, methods, repeat=repeat, verbose=verbose
)
end = time()

results_df = pandas.DataFrame(results)
Expand All @@ -207,21 +227,24 @@ def run_bench(repeat=100, verbose=False):
return results_df


if __name__ == '__main__':
if __name__ == "__main__":
from datetime import datetime
import sklearn
import numpy
import onnx
import onnxruntime
import skl2onnx
df = pandas.DataFrame([
{"name": "date", "version": str(datetime.now())},
{"name": "numpy", "version": numpy.__version__},
{"name": "scikit-learn", "version": sklearn.__version__},
{"name": "onnx", "version": onnx.__version__},
{"name": "onnxruntime", "version": onnxruntime.__version__},
{"name": "skl2onnx", "version": skl2onnx.__version__},
])

df = pandas.DataFrame(
[
{"name": "date", "version": str(datetime.now())},
{"name": "numpy", "version": numpy.__version__},
{"name": "scikit-learn", "version": sklearn.__version__},
{"name": "onnx", "version": onnx.__version__},
{"name": "onnxruntime", "version": onnxruntime.__version__},
{"name": "skl2onnx", "version": skl2onnx.__version__},
]
)
df.to_csv("bench_plot_onnxruntime_decision_tree.time.csv", index=False)
print(df)
df = run_bench(verbose=True)
Expand Down
Loading

0 comments on commit 293312c

Please sign in to comment.