[python] tests for plot tree functions and module_INSTALLED variables (#1438)

* removed excess import

* added tests for plotting trees in Python

* refined module_INSTALLED mechanism

* added a note that create_tree_digraph is preferable to plot_tree
StrikerRUS authored Jun 20, 2018
1 parent c184852 commit 5fe2bdd
Showing 8 changed files with 114 additions and 72 deletions.
3 changes: 1 addition & 2 deletions .travis/test.sh
@@ -48,7 +48,7 @@ if [[ $TASK == "if-else" ]]; then
exit 0
fi

-conda install numpy nose scipy scikit-learn pandas matplotlib pytest
+conda install numpy nose scipy scikit-learn pandas matplotlib python-graphviz pytest

if [[ $TASK == "sdist" ]]; then
cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1
@@ -98,7 +98,6 @@ cd $TRAVIS_BUILD_DIR/python-package && python setup.py install --precompile || e
pytest $TRAVIS_BUILD_DIR || exit -1

if [[ $TASK == "regular" ]]; then
-    conda install python-graphviz
cd $TRAVIS_BUILD_DIR/examples/python-guide
sed -i'.bak' '/import lightgbm as lgb/a\
import matplotlib\
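
Installing python-graphviz up front matters for the new tree-plotting tests further down: graph.render() needs the Graphviz executables in addition to the Python bindings, and the conda package is expected to pull both in (that dependency chain is an assumption here, not something the diff states). A quick Python 3 sanity check of both halves might look like this:

import shutil

try:
    import graphviz
    print('graphviz Python bindings:', graphviz.__version__)
except ImportError:
    print('graphviz Python bindings: missing')

# graph.render() shells out to the `dot` executable, so check for it as well
print('dot executable:', shutil.which('dot') or 'missing')
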
4 changes: 2 additions & 2 deletions examples/python-guide/plot_example.py
@@ -3,9 +3,9 @@
import lightgbm as lgb
import pandas as pd

-try:
+if lgb.compat.MATPLOTLIB_INSTALLED:
     import matplotlib.pyplot as plt
-except ImportError:
+else:
     raise ImportError('You need to install matplotlib for plot_example.py.')

# load or create your dataset
17 changes: 17 additions & 0 deletions python-package/lightgbm/compat.py
@@ -57,13 +57,30 @@ def json_default_with_numpy(obj):
"""pandas"""
try:
    from pandas import Series, DataFrame
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class Series(object):
        pass

    class DataFrame(object):
        pass

"""matplotlib"""
try:
    import matplotlib
    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz
    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""sklearn"""
try:
    from sklearn.base import BaseEstimator
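
The refined mechanism centralizes the optional-dependency checks here: each block attempts the import once and records the outcome in a module-level *_INSTALLED flag that the rest of the package and the test suite can branch on. A minimal sketch of how downstream code might consume these flags; check_plotting_backends is an illustrative helper, not part of the library:

from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED


def check_plotting_backends():
    """Report which optional plotting backends are available (illustrative helper)."""
    backends = {}
    if MATPLOTLIB_INSTALLED:
        import matplotlib  # imported only when the flag says it is safe to do so
        backends['matplotlib'] = matplotlib.__version__
    if GRAPHVIZ_INSTALLED:
        import graphviz
        backends['graphviz'] = graphviz.__version__
    return backends


if __name__ == '__main__':
    print(check_plotting_backends())
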
22 changes: 14 additions & 8 deletions python-package/lightgbm/plotting.py
@@ -10,6 +10,7 @@
import numpy as np

from .basic import Booster
from .compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from .sklearn import LGBMModel


@@ -69,9 +70,9 @@ def plot_importance(booster, ax=None, height=0.2,
ax : matplotlib.axes.Axes
The plot with model's feature importances.
"""
-    try:
+    if MATPLOTLIB_INSTALLED:
         import matplotlib.pyplot as plt
-    except ImportError:
+    else:
         raise ImportError('You must install matplotlib to plot importance.')

if isinstance(booster, LGBMModel):
@@ -173,9 +174,9 @@ def plot_metric(booster, metric=None, dataset_names=None,
ax : matplotlib.axes.Axes
The plot with metric's history over the training.
"""
-    try:
+    if MATPLOTLIB_INSTALLED:
         import matplotlib.pyplot as plt
-    except ImportError:
+    else:
         raise ImportError('You must install matplotlib to plot metric.')

if isinstance(booster, LGBMModel):
@@ -261,9 +262,9 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None,
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
-    try:
+    if GRAPHVIZ_INSTALLED:
         from graphviz import Digraph
-    except ImportError:
+    else:
         raise ImportError('You must install graphviz to plot tree.')

def float2str(value, precision=None):
@@ -399,6 +400,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
show_info=None, precision=None):
"""Plot specified tree.

Note
----
It is preferable to use ``create_tree_digraph()`` because of its lossless quality,
and its returned objects can also be rendered and displayed directly inside a Jupyter notebook.

Parameters
----------
booster : Booster or LGBMModel
@@ -430,10 +436,10 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
ax : matplotlib.axes.Axes
The plot with single tree.
"""
-    try:
+    if MATPLOTLIB_INSTALLED:
         import matplotlib.pyplot as plt
         import matplotlib.image as image
-    except ImportError:
+    else:
         raise ImportError('You must install matplotlib to plot tree.')

if ax is None:
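
A short usage sketch for the note added to plot_tree above (the dataset, file name and parameters are illustrative, not taken from the diff): create_tree_digraph() returns a graphviz.Digraph, which stays vector-based ("lossless") and is displayed directly when left as the last expression of a Jupyter notebook cell, whereas plot_tree() renders the same tree to an image and puts it on a matplotlib Axes:

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
bst = lgb.train({'objective': 'binary', 'verbose': -1, 'num_leaves': 3},
                lgb.Dataset(X, y), num_boost_round=10)

# Lossless, notebook-friendly variant: leave `graph` as the last expression in a
# cell to display it inline, or render it to a file on disk.
graph = lgb.create_tree_digraph(bst, tree_index=0, show_info=['split_gain'])
graph.render('tree0', view=False)

# Matplotlib variant: returns an Axes holding a rasterized image of the same tree.
ax = lgb.plot_tree(bst, tree_index=0, figsize=(15, 8))
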
1 change: 0 additions & 1 deletion tests/python_package_test/test_basic.py
@@ -1,7 +1,6 @@
# coding: utf-8
# pylint: skip-file
import os
-import subprocess
import tempfile
import unittest

9 changes: 2 additions & 7 deletions tests/python_package_test/test_engine.py
@@ -14,12 +14,6 @@
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from scipy.sparse import csr_matrix

-try:
-    import pandas as pd
-    IS_PANDAS_INSTALLED = True
-except ImportError:
-    IS_PANDAS_INSTALLED = False

try:
import cPickle as pickle
except ImportError:
@@ -478,8 +472,9 @@ def test_template(init_model=None, return_model=False):
for ret in other_ret:
self.assertAlmostEqual(ret_origin, ret, places=5)

-    @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
+    @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
     def test_pandas_categorical(self):
+        import pandas as pd
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
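
The replacement above relies on two things: the test is gated with @unittest.skipIf on the new compat flag, and the pandas import moves into the test body so that merely collecting the module never requires pandas. A self-contained sketch of that pattern (the test name and data are illustrative, not from the diff):

import unittest

import numpy as np
import lightgbm as lgb


class PandasGatedTest(unittest.TestCase):

    @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
    def test_dataframe_dataset(self):
        import pandas as pd  # imported only when the skip condition is False
        df = pd.DataFrame({'A': np.random.rand(100)})
        label = np.random.randint(0, 2, 100)
        ds = lgb.Dataset(df, label=label).construct()
        self.assertEqual(ds.num_data(), 100)


if __name__ == '__main__':
    unittest.main()
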
88 changes: 59 additions & 29 deletions tests/python_package_test/test_plotting.py
@@ -3,30 +3,31 @@
import unittest

import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

-try:
+if MATPLOTLIB_INSTALLED:
     import matplotlib
     matplotlib.use('Agg')
-    matplotlib_installed = True
-except ImportError:
-    matplotlib_installed = False
+if GRAPHVIZ_INSTALLED:
+    import graphviz


class TestBasic(unittest.TestCase):

-    @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
-    def test_plot_importance(self):
-        X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
-        train_data = lgb.Dataset(X_train, y_train)
-
-        params = {
+    def setUp(self):
+        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
+        self.train_data = lgb.Dataset(self.X_train, self.y_train)
+        self.params = {
             "objective": "binary",
             "verbose": -1,
             "num_leaves": 3
         }
-        gbm0 = lgb.train(params, train_data, num_boost_round=10)
+
+    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
+    def test_plot_importance(self):
+        gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10)
ax0 = lgb.plot_importance(gbm0)
self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Feature importance')
@@ -35,7 +36,7 @@ def test_plot_importance(self):
self.assertLessEqual(len(ax0.patches), 30)

gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
-        gbm1.fit(X_train, y_train)
+        gbm1.fit(self.X_train, self.y_train)

ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
self.assertIsInstance(ax1, matplotlib.axes.Axes)
Expand All @@ -58,26 +59,55 @@ def test_plot_importance(self):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b

-    @unittest.skip('Graphviz are not executables on Travis')
+    @unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed')
     def test_plot_tree(self):
-        pass
+        gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
+        gbm.fit(self.X_train, self.y_train, verbose=False)

-    @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
-    def test_plot_metrics(self):
-        X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
-        train_data = lgb.Dataset(X_train, y_train)
-        test_data = lgb.Dataset(X_test, y_test, reference=train_data)
+        self.assertRaises(IndexError, lgb.plot_tree, gbm, tree_index=83)

-        params = {
-            "objective": "binary",
-            "metric": {"binary_logloss", "binary_error"},
-            "verbose": -1,
-            "num_leaves": 3
-        }
+        ax = lgb.plot_tree(gbm, tree_index=3, figsize=(15, 8), show_info=['split_gain'])
+        self.assertIsInstance(ax, matplotlib.axes.Axes)
+        w, h = ax.axes.get_figure().get_size_inches()
+        self.assertEqual(int(w), 15)
+        self.assertEqual(int(h), 8)

+    @unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed')
+    def test_create_tree_digraph(self):
+        gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
+        gbm.fit(self.X_train, self.y_train, verbose=False)
+
+        self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)
+
+        graph = lgb.create_tree_digraph(gbm, tree_index=3,
+                                        show_info=['split_gain', 'internal_value'],
+                                        name='Tree4', node_attr={'color': 'red'})
+        graph.render(view=False)
+        self.assertIsInstance(graph, graphviz.Digraph)
+        self.assertEqual(graph.name, 'Tree4')
+        self.assertEqual(graph.filename, 'Tree4.gv')
+        self.assertEqual(len(graph.node_attr), 1)
+        self.assertEqual(graph.node_attr['color'], 'red')
+        self.assertEqual(len(graph.graph_attr), 0)
+        self.assertEqual(len(graph.edge_attr), 0)
+        graph_body = ''.join(graph.body)
+        self.assertIn('threshold', graph_body)
+        self.assertIn('split_feature_name', graph_body)
+        self.assertNotIn('split_feature_index', graph_body)
+        self.assertIn('leaf_index', graph_body)
+        self.assertIn('split_gain', graph_body)
+        self.assertIn('internal_value', graph_body)
+        self.assertNotIn('internal_count', graph_body)
+        self.assertNotIn('leaf_count', graph_body)
+
+    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
+    def test_plot_metrics(self):
+        test_data = lgb.Dataset(self.X_test, self.y_test, reference=self.train_data)
+        self.params.update({"metric": {"binary_logloss", "binary_error"}})

evals_result0 = {}
-        gbm0 = lgb.train(params, train_data,
-                         valid_sets=[train_data, test_data],
+        gbm0 = lgb.train(self.params, self.train_data,
+                         valid_sets=[self.train_data, test_data],
valid_names=['v1', 'v2'],
num_boost_round=10,
evals_result=evals_result0,
@@ -91,14 +121,14 @@ def test_plot_metrics(self):
ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])

evals_result1 = {}
-        gbm1 = lgb.train(params, train_data,
+        gbm1 = lgb.train(self.params, self.train_data,
num_boost_round=10,
evals_result=evals_result1,
verbose_eval=False)
self.assertRaises(ValueError, lgb.plot_metric, evals_result1)

gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
-        gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
+        gbm2.fit(self.X_train, self.y_train, eval_set=[(self.X_test, self.y_test)], verbose=False)
ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '')
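
The assertions in the new test_create_tree_digraph above work because graphviz.Digraph.body is a list of DOT source lines, so substring checks on the joined body show which pieces of tree information were actually drawn. A minimal standalone sketch of the same idea (the model and parameters are illustrative):

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True).fit(X, y)

graph = lgb.create_tree_digraph(gbm, tree_index=3, show_info=['split_gain'])
dot_source = ''.join(graph.body)       # DOT statements, one string per line
print('split_gain' in dot_source)      # True: requested via show_info
print('internal_count' in dot_source)  # False: not requested
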
42 changes: 19 additions & 23 deletions tests/python_package_test/test_sklearn.py
@@ -19,11 +19,6 @@
sklearn_at_least_019 = True
except ImportError:
sklearn_at_least_019 = False
-try:
-    import pandas as pd
-    IS_PANDAS_INSTALLED = True
-except ImportError:
-    IS_PANDAS_INSTALLED = False


def multi_error(y_true, y_pred):
@@ -182,26 +177,27 @@ def test_sklearn_backward_compatibility(self):
y_pred_2 = clf_2.fit(X_train, y_train).predict_proba(X_test)
np.testing.assert_allclose(y_pred_1, y_pred_2)

+    # sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
+    @unittest.skipIf(not sklearn_at_least_019, 'scikit-learn version is less than 0.19')
     def test_sklearn_integration(self):
-        # sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
-        if sklearn_at_least_019:
-            # we cannot use `check_estimator` directly since there is no skip test mechanism
-            for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier),
-                                    (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)):
-                check_parameters_default_constructible(name, estimator)
-                check_no_fit_attributes_set_in_init(name, estimator)
-                # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833)
-                estimator = estimator(min_child_samples=1, min_data_in_bin=1)
-                for check in _yield_all_checks(name, estimator):
-                    if check.__name__ == 'check_estimators_nan_inf':
-                        continue  # skip test because LightGBM deals with nan
-                    try:
-                        check(name, estimator)
-                    except SkipTest as message:
-                        warnings.warn(message, SkipTestWarning)

-    @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
+        # we cannot use `check_estimator` directly since there is no skip test mechanism
+        for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier),
+                                (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)):
+            check_parameters_default_constructible(name, estimator)
+            check_no_fit_attributes_set_in_init(name, estimator)
+            # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833)
+            estimator = estimator(min_child_samples=1, min_data_in_bin=1)
+            for check in _yield_all_checks(name, estimator):
+                if check.__name__ == 'check_estimators_nan_inf':
+                    continue  # skip test because LightGBM deals with nan
+                try:
+                    check(name, estimator)
+                except SkipTest as message:
+                    warnings.warn(message, SkipTestWarning)

+    @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
     def test_pandas_categorical(self):
+        import pandas as pd
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
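
A side effect of the test_sklearn_integration refactor above: with the old in-body `if sklearn_at_least_019:` guard, an outdated scikit-learn silently reports the test as passed, while @unittest.skipIf reports it as skipped together with the reason. A self-contained sketch of the difference (the flag value is hard-coded purely for illustration):

import unittest

SKLEARN_AT_LEAST_019 = False  # hard-coded here only to demonstrate the two behaviours


class GuardStyles(unittest.TestCase):

    def test_inline_guard(self):
        if SKLEARN_AT_LEAST_019:  # body is skipped silently and the test still "passes"
            self.assertTrue(True)

    @unittest.skipIf(not SKLEARN_AT_LEAST_019, 'scikit-learn version is less than 0.19')
    def test_skipif_decorator(self):  # reported as skipped, with the reason shown
        self.assertTrue(True)


if __name__ == '__main__':
    unittest.main(verbosity=2)
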
