From 839367f8e411ec5a5ab25ffc80087f5af7f94877 Mon Sep 17 00:00:00 2001 From: Piotr Date: Thu, 2 Sep 2021 13:23:34 +0200 Subject: [PATCH] limit number of points in decision tree regressor visualization (#462) --- supervised/algorithms/decision_tree.py | 28 ++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/supervised/algorithms/decision_tree.py b/supervised/algorithms/decision_tree.py index a2284b4b..b30bfcbe 100644 --- a/supervised/algorithms/decision_tree.py +++ b/supervised/algorithms/decision_tree.py @@ -20,6 +20,7 @@ from sklearn.tree import _tree from dtreeviz.trees import dtreeviz +from supervised.utils.subsample import subsample def get_rules(tree, feature_names, class_names): @@ -204,14 +205,25 @@ def interpret( if explain_level == 0: return try: - - viz = dtreeviz( - self.model, - X_train, - y_train, - target_name="target", - feature_names=X_train.columns, - ) + # 250 is hard limit for number of points used in visualization + # if too many points are used then final SVG plot is very large (can be > 100MB) + if X_train.shape[0] > 250: + x, _, y, _ = subsample(X_train, y_train, REGRESSION, 250) + viz = dtreeviz( + self.model, + x, + y, + target_name="target", + feature_names=x.columns, + ) + else: + viz = dtreeviz( + self.model, + X_train, + y_train, + target_name="target", + feature_names=X_train.columns, + ) tree_file_plot = os.path.join(model_file_path, learner_name + "_tree.svg") viz.save(tree_file_plot) except Exception as e: