# Patch summary (unified diff against notebooks/tb_validation.ipynb; the hunk
# lines are JSON string literals holding the source of a notebook code cell).
#
# Hunk 1 (-365,7 +365,7): the `results` DataFrame is created with four columns
#   (true/false positives/negatives); the 'Not tested' and 'Unknown' columns
#   are dropped.
# Hunk 2 (-376,7 +376,7): `total_count` changes from len(km_tb_df) to the sum
#   true_positives + true_negatives + false_positives + false_negatives, and
#   that value is what gets appended to `n_counts`.
# Hunk 3 (-397,12 +397,10): each `results` row now stores the four raw counts
#   instead of each count divided by total_count and multiplied by 100.
# Hunk 4 (-410,12 +408,12): the colour list passed to results.plot and the
#   legend labels are reduced from six entries to four.
#
# NOTE(review): the unchanged context lines still set the x-label to
#   'Percentage (%)', call ax.set_xlim(0, 110), and draw the 'n=...' text at
#   x=100, while the rows now hold raw counts rather than percentages --
#   confirm whether the axis label/limits should change or the rows should
#   stay percentage-scaled.
# NOTE(review): `not_tested` and `unknown` are still computed in the context
#   lines of hunk 2 but are no longer stored in `results`; whether they are
#   used elsewhere in the cell is not visible here -- TODO confirm.
# NOTE(review): because `n_counts` now receives the four-class sum, the
#   'n=...' annotation no longer reflects len(km_tb_df) -- confirm intended.
# NOTE(review): in this copy the hunk lines appear joined by single spaces
#   (line breaks and indentation collapsed); presumably it needs the original
#   line structure restored before it can be applied -- verify.
diff --git a/notebooks/tb_validation.ipynb b/notebooks/tb_validation.ipynb index 4be5ca5..523494b 100644 --- a/notebooks/tb_validation.ipynb +++ b/notebooks/tb_validation.ipynb @@ -365,7 +365,7 @@ "outputs": [], "source": [ "# Initialize a DataFrame to hold the results\n", - "results = pd.DataFrame(columns=['True positives', 'True negatives', 'False positives', 'False negatives', 'Not tested', 'Unknown'])\n", + "results = pd.DataFrame(columns=['True positives', 'True negatives', 'False positives', 'False negatives'])\n", "\n", "ass_loa = []\n", "n_counts = []\n", @@ -376,7 +376,7 @@ " false_negatives = ((km_tb_df[col1] == \"R\") & (km_tb_df[col2].isna())).sum()\n", " not_tested = (km_tb_df[col1].isin([\"ej testad\", \"Ej testad\", np.nan])).sum()\n", " unknown = (~km_tb_df[col1].isin([\"ej testad\", \"Ej testad\", \"R\", \"S\", np.nan])).sum()\n", - " total_count = len(km_tb_df)\n", + " total_count = true_positives + true_negatives + false_positives + false_negatives\n", " n_counts.append(total_count)\n", " if true_positives + true_negatives + false_positives + false_negatives == 0:\n", " accuracy = 0.0\n", @@ -397,12 +397,10 @@ " \"specificity\": specificity\n", " })\n", " results.loc[f'{col1} vs {col2}'] = [\n", - " true_positives / total_count * 100,\n", - " true_negatives / total_count * 100,\n", - " false_positives / total_count * 100,\n", - " false_negatives / total_count * 100,\n", - " not_tested / total_count * 100,\n", - " unknown / total_count * 100\n", + " true_positives,\n", + " true_negatives,\n", + " false_positives,\n", + " false_negatives\n", " ]\n", "\n", "# Write out csv\n", @@ -410,12 +408,12 @@ "\n", "# Plotting\n", "fig, ax = plt.subplots(figsize=(10, 8))\n", - "results.plot(kind='barh', stacked=True, ax=ax, color=['green', 'gold', 'orange', 'red', 'darkblue', 'lightblue'])\n", + "results.plot(kind='barh', stacked=True, ax=ax, color=['green', 'gold', 'orange', 'red'])\n", "\n", "# Customizing plot\n", "ax.set_xlabel('Percentage (%)')\n", 
"ax.set_title('AMR calling accuracy')\n", - "ax.legend(['True positives', 'True negatives', 'False positives', 'False negatives', 'Not tested', 'Unknown'], bbox_to_anchor=(1.05, 1), loc='upper left')\n", + "ax.legend(['True positives', 'True negatives', 'False positives', 'False negatives'], bbox_to_anchor=(1.05, 1), loc='upper left')\n", "ax.set_xlim(0, 110)\n", "for i in range(len(results)):\n", " ax.text(100, i, f'n={n_counts[i]}', ha='left', va='center', fontsize=10, color='black')\n",