Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
hadifawaz1999 committed Nov 13, 2024
1 parent 60e5646 commit f01518d
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 140 deletions.
55 changes: 48 additions & 7 deletions aeon/visualisation/results/_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def create_multi_comparison_matrix(
order_win_tie_loss="higher",
include_pvalue=True,
pvalue_test="wilcoxon",
pvalue_test_params=None,
pvalue_correction=None,
pvalue_threshold=0.05,
use_mean="mean-difference",
Expand Down Expand Up @@ -89,16 +90,33 @@ def create_multi_comparison_matrix(
pvalue_test: str, default = 'wilcoxon'
The statistical test to produce the pvalue stats. Currently only wilcoxon is
supported.
pvalue_test_params: dict, default = None,
The default parameter set for the pvalue_test used. If pvalue_test is set
to Wilcoxon, one should check the scipy.stats.wilcoxon parameters,
in the case Wilcoxon is set and this parameter is None, then the default setup
is {"zero_method": "pratt", "alternative": "greater"}.
pvalue_correction: str, default = None
Correction to use for the pvalue significant test, None or "Holm".
pvalue_threshold: float, default = 0.05
Threshold for considering a comparison is significant or not. If pvalue <
pvalue_threshhold -> comparison is significant.
use_mean: str, default = 'mean-difference'
The mean used to compare two estimators.
The mean used to compare two estimators. The only option available
is 'mean-difference' which is the difference between arithmetic mean
over all datasets.
order_stats: str, default = 'average-statistic'
The way to order the used_statistic, default setup orders by average
statistic over all datasets.
The options are:
===============================================================
method what it does
===============================================================
average-statistic average used_statistic over all datasets
average-rank average rank over all datasets
max-wins maximum number of wins over all datasets
amean-amean average over difference of use_mean
pvalue average pvalue over all comparates
================================================================
order_better: str, default = 'decreasing'
By which order to sort stats, from best to worse.
dataset_column: str, default = 'dataset_name'
Expand Down Expand Up @@ -175,6 +193,7 @@ def create_multi_comparison_matrix(
order_win_tie_loss=order_win_tie_loss,
include_pvalue=include_pvalue,
pvalue_test=pvalue_test,
pvalue_test_params=pvalue_test_params,
pvalue_correction=pvalue_correction,
pvalue_threshhold=pvalue_threshold,
use_mean=use_mean,
Expand Down Expand Up @@ -221,6 +240,7 @@ def _get_analysis(
order_win_tie_loss="higher",
include_pvalue=True,
pvalue_test="wilcoxon",
pvalue_test_params=None,
pvalue_correction=None,
pvalue_threshhold=0.05,
use_mean="mean-difference",
Expand Down Expand Up @@ -277,7 +297,16 @@ def _plot_1v1(
ax.set_xlabel(name_y, fontsize=fontsize)
ax.set_ylabel(name_x, fontsize=fontsize)

p_value = round(wilcoxon(x=x, y=y, zero_method="pratt")[1], 4)
if pvalue_test == "wilcoxon":
_pvalue_test_params = {}
if pvalue_test_params is None:
_pvalue_test_params = {"zero_method": "pratt", "alternative": "greater"}
else:
_pvalue_test_params = pvalue_test_params
p_value = round(wilcoxon(x=x, y=y, **_pvalue_test_params)[1], precision)
else:
raise ValueError("The test " + pvalue_test + " is not yet supported.")

legend_elements = [
mpl.lines.Line2D(
[], [], marker="o", color="blue", label=f"Win {win_x}", linestyle="None"
Expand Down Expand Up @@ -354,8 +383,10 @@ def _plot_1v1(
pairwise_content = _get_pairwise_content(
x=x,
y=y,
order_WinTieLoss=order_win_tie_loss,
include_pvalue=include_pvalue,
pvalue_test=pvalue_test,
pvalue_test_params=pvalue_test_params,
pvalue_threshhold=pvalue_threshhold,
use_mean=use_mean,
)
Expand Down Expand Up @@ -672,7 +703,10 @@ def _draw(
latex_cell = "\\rule{0em}{3ex} " + df_annotations_np[i, j].replace(
"\n", " \\\\ "
)
r = [str(round(_, 4)) for _ in cm(cm_norm(pairwise_matrix[i, j]))[:-1]]
r = [
str(round(_, precision))
for _ in cm(cm_norm(pairwise_matrix[i, j]))[:-1]
]
latex_row.append(
f"\\cellcolor[rgb]{{{','.join(r)}}}\\shortstack{{{latex_cell}}}"
)
Expand All @@ -697,7 +731,10 @@ def _draw(
"\n", " \\\\ "
)
s1 = f"{latex_bold}\\cellcolor[rgb]"
s2 = [str(round(_, 4)) for _ in cm(cm_norm(pairwise_matrix[i, j]))[:-1]]
s2 = [
str(round(_, precision))
for _ in cm(cm_norm(pairwise_matrix[i, j]))[:-1]
]
s = f"{s1}{{{','.join(s2)}}}\\shortstack{{{latex_cell}}}"
latex_row.append(s)

Expand Down Expand Up @@ -915,12 +952,11 @@ def _get_pairwise_content(
x,
y,
order_WinTieLoss="higher",
includeProbaWinTieLoss=False,
include_pvalue=True,
pvalue_test="wilcoxon",
pvalue_test_params=None,
pvalue_threshhold=0.05,
use_mean="mean-difference",
bayesian_rope=0.01,
):
content = {}

Expand All @@ -940,7 +976,12 @@ def _get_pairwise_content(

if include_pvalue:
if pvalue_test == "wilcoxon":
pvalue = wilcoxon(x=x, y=y, zero_method="pratt", alternative="greater")[1]
_pvalue_test_params = {}
if pvalue_test_params is None:
_pvalue_test_params = {"zero_method": "pratt", "alternative": "greater"}
else:
_pvalue_test_params = pvalue_test_params
pvalue = wilcoxon(x=x, y=y, **_pvalue_test_params)[1]
content["pvalue"] = pvalue

if pvalue_test == "wilcoxon":
Expand Down
Loading

0 comments on commit f01518d

Please sign in to comment.