Skip to content

Commit

Permalink
add max_depth hyperparameter to Random Forest and Extra Trees (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Jul 21, 2020
1 parent 176a393 commit bbe304c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
14 changes: 10 additions & 4 deletions supervised/algorithms/extra_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(self, params):
self.model = ExtraTreesClassifier(
n_estimators=self.trees_in_step,
criterion=params.get("criterion", "gini"),
max_features=params.get("max_features", 0.6),
min_samples_split=params.get("min_samples_split", 30),
max_features=params.get("max_features", 0.8),
max_depth=params.get("max_depth", 6),
min_samples_split=params.get("min_samples_split", 4),
warm_start=True,
n_jobs=-1,
random_state=params.get("seed", 1),
Expand Down Expand Up @@ -65,8 +66,9 @@ def __init__(self, params):
self.model = ExtraTreesRegressor(
n_estimators=self.trees_in_step,
criterion=params.get("criterion", "mse"),
max_features=params.get("max_features", 0.8),
min_samples_split=params.get("min_samples_split", 4),
max_features=params.get("max_features", 0.6),
max_depth=params.get("max_depth", 6),
min_samples_split=params.get("min_samples_split", 30),
warm_start=True,
n_jobs=-1,
random_state=params.get("seed", 1),
Expand All @@ -81,12 +83,14 @@ def file_extension(self):
"criterion": ["gini", "entropy"],
"max_features": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
"min_samples_split": [10, 20, 30, 40, 50],
"max_depth": [4,6,8,10,12]
}

classification_default_params = {
"criterion": "gini",
"max_features": 0.6,
"min_samples_split": 30,
"max_depth": 6
}

additional = {
Expand Down Expand Up @@ -133,12 +137,14 @@ def file_extension(self):
], # remove "mae" because it slows down a lot https://github.com/scikit-learn/scikit-learn/issues/9626
"max_features": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
"min_samples_split": [10, 20, 30, 40, 50],
"max_depth": [4,6,8,10,12]
}

regression_default_params = {
"criterion": "mse",
"max_features": 0.6,
"min_samples_split": 30,
"max_depth": 6
}

regression_additional = {
Expand Down
6 changes: 6 additions & 0 deletions supervised/algorithms/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, params):
n_estimators=self.trees_in_step,
criterion=params.get("criterion", "gini"),
max_features=params.get("max_features", 0.8),
max_depth=params.get("max_depth", 6),
min_samples_split=params.get("min_samples_split", 4),
warm_start=True,
n_jobs=-1,
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, params):
n_estimators=self.trees_in_step,
criterion=params.get("criterion", "mse"),
max_features=params.get("max_features", 0.8),
max_depth=params.get("max_depth", 6),
min_samples_split=params.get("min_samples_split", 4),
warm_start=True,
n_jobs=-1,
Expand All @@ -82,12 +84,14 @@ def file_extension(self):
"criterion": ["gini", "entropy"],
"max_features": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
"min_samples_split": [10, 20, 30, 40, 50],
"max_depth": [4,6,8,10,12]
}

classification_default_params = {
"criterion": "gini",
"max_features": 0.6,
"min_samples_split": 30,
"max_depth": 6
}


Expand Down Expand Up @@ -137,12 +141,14 @@ def file_extension(self):
], # remove "mae" because it slows down a lot https://github.com/scikit-learn/scikit-learn/issues/9626
"max_features": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
"min_samples_split": [10, 20, 30, 40, 50],
"max_depth": [4,6,8,10,12]
}

regression_default_params = {
"criterion": "mse",
"max_features": 0.6,
"min_samples_split": 30,
"max_depth": 6
}

regression_additional = {
Expand Down

0 comments on commit bbe304c

Please sign in to comment.