diff --git a/pyproject.toml b/pyproject.toml index e5a121c..340f95b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "riix" -version = "0.0.2" +version = "0.0.3" description = "vectorized implementations of online rating systems" readme = "README.md" license = {file = "LICENSE"} diff --git a/riix/metrics.py b/riix/metrics.py index 03b746b..8770760 100644 --- a/riix/metrics.py +++ b/riix/metrics.py @@ -11,6 +11,13 @@ def binary_accuracy(probs: np.ndarray, outcomes: np.ndarray) -> float: correct = outcomes[pos_mask].sum() + (1.0 - outcomes[neg_mask]).sum() + 0.5 * draw_mask.sum() return correct / probs.shape[0] +def accuracy_without_draws(probs: np.ndarray, outcomes: np.ndarray) -> float: + """compute binary accuracy after first filtering out rows where the label is a draw""" + draw_mask = outcomes == 0.5 + probs = probs[~draw_mask] + outcomes = outcomes[~draw_mask] + return binary_accuracy(probs, outcomes) + def accuracy_with_draws(probs: np.ndarray, outcomes: np.ndarray, draw_margin=0.0) -> float: """computes accuracy while allowing for ties""" @@ -39,6 +46,7 @@ def binary_metrics_suite(probs: np.ndarray, outcomes: np.ndarray): """a wrapper class for running a bunch of binary metrics""" metrics = { 'accuracy': binary_accuracy(probs, outcomes), + 'accuracy_without_draws' : accuracy_without_draws(probs, outcomes), 'log_loss': binary_log_loss(probs, outcomes), 'brier_score': brier_score(probs, outcomes), }