Skip to content

Commit

Permalink
Changed default RF params
Browse files Browse the repository at this point in the history
  • Loading branch information
mese79 committed Sep 9, 2024
1 parent cbe7c64 commit 95133d2
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/featureforest/_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,26 +172,19 @@ def create_label_stats_ui(self):
def create_train_ui(self):
tree_label = QLabel("Number of trees:")
self.num_trees_textbox = QLineEdit()
self.num_trees_textbox.setText("300")
self.num_trees_textbox.setText("450")
self.num_trees_textbox.setValidator(QIntValidator(1, 99999))

depth_label = QLabel("Max depth:")
self.max_depth_textbox = QLineEdit()
self.max_depth_textbox.setText("7")
self.max_depth_textbox.setText("9")
self.max_depth_textbox.setValidator(QIntValidator(0, 99999))
self.max_depth_textbox.setToolTip("set to 0 for unlimited depth.")

train_button = QPushButton("Train RF Model")
train_button.clicked.connect(self.train_model)
train_button.setMinimumWidth(150)

# self.sam_progress = QProgressBar()
# self.save_storage_button = QPushButton("Save SAM Embeddings")
# self.save_storage_button.clicked.connect(self.save_embeddings)
# self.save_storage_button.setMinimumWidth(150)
# self.save_storage_button.setMaximumWidth(150)
# self.save_storage_button.setEnabled(False)

self.model_status_label = QLabel("Model status:")

load_button = QPushButton("Load Model")
Expand Down Expand Up @@ -609,7 +602,10 @@ def train_model(self):
rf_classifier = RandomForestClassifier(
n_estimators=num_trees,
max_depth=max_depth,
min_samples_leaf=1,
class_weight="balanced",
min_samples_split=15,
min_samples_leaf=3,
max_features=25,
n_jobs=2 if os.cpu_count() < 5 else os.cpu_count() - 3,
verbose=1
)
Expand Down

0 comments on commit 95133d2

Please sign in to comment.