Skip to content

Commit

Permalink
add fixes for multi-class
Browse files Browse the repository at this point in the history
  • Loading branch information
Gscorreia89 committed Jul 22, 2024
1 parent b38a741 commit 05d4aac
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions pyChemometrics/ChemometricsPLSDA.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,28 @@ def fit(self, x, y, **fit_params):

else:
y_pred = self.predict(x)
accuracy = metrics.accuracy_score(y, y_pred)
precision = metrics.precision_score(y, y_pred, average='weighted')
recall = metrics.recall_score(y, y_pred, average='weighted')
misclassified_samples = np.where(y.ravel() != y_pred.ravel())[0]
f1_score = metrics.f1_score(y, y_pred, average='weighted')
conf_matrix = metrics.confusion_matrix(y, y_pred)
zero_oneloss = metrics.zero_one_loss(y, y_pred)

# Make dummy matrix into a label vector for scoring
if isDummy:
# y = np.where(y == 1)[1]
y_vec = np.where(y == 1)[1]
else:
y_vec = y

accuracy = metrics.accuracy_score(y_vec, y_pred)
precision = metrics.precision_score(y_vec, y_pred, average='weighted')
recall = metrics.recall_score(y_vec, y_pred, average='weighted')
misclassified_samples = np.where(y_vec.ravel() != y_pred.ravel())[0]
f1_score = metrics.f1_score(y_vec, y_pred, average='weighted')
conf_matrix = metrics.confusion_matrix(y_vec, y_pred)
zero_oneloss = metrics.zero_one_loss(y_vec, y_pred)
matthews_mcc = np.nan
roc_curve = list()
auc_area = list()

# Generate multiple ROC curves - one for each class the multiple class case
for predclass in range(self.n_classes):
current_roc = metrics.roc_curve(y, class_score[:, predclass], pos_label=predclass)
current_roc = metrics.roc_curve(y_vec, class_score[:, predclass], pos_label=predclass)
# Interpolate all ROC curves to a finite grid
# Makes it easier to average and compare multiple models - with CV in mind
tpr = current_roc[1]
Expand Down

0 comments on commit 05d4aac

Please sign in to comment.