diff --git a/bangmetric/dprime.py b/bangmetric/dprime.py index bd15b87..3e5ac94 100644 --- a/bangmetric/dprime.py +++ b/bangmetric/dprime.py @@ -6,17 +6,17 @@ # # License: BSD -__all__ = ['dprime', 'dprime_ova_from_confusion'] +__all__ = ['dprime', 'dprime_from_confusion_ova'] import numpy as np from scipy.stats import norm DEFAULT_FUDGE_FACTOR = 0.5 DEFAULT_FUDGE_MODE = 'correction' -ATOL = 1e-7 +ATOL = 1e-6 -def dprime(y_pred, y_true): +def dprime(y_pred, y_true, **kwargs): """Computes the d-prime sensitivity index of the predictions. Parameters @@ -29,10 +29,14 @@ def dprime(y_pred, y_true): y_pred: array, shape = [n_samples] Predicted values (real). + kwargs: named arguments, optional + Passed to ``dprime_from_samp()``. + Returns ------- dp: float or None - d-prime, None if d-prime is undefined + d-prime, or None if d-prime is non-finite (undefined) and the raw value + was not requested via ``safedp=False``. References ---------- @@ -51,23 +55,78 @@ def dprime(y_pred, y_true): assert y_pred.ndim == 1 # -- actual computation - pos = y_true > 0 - neg = ~pos + i_pos = y_true > 0 + i_neg = ~i_pos + + pos = y_pred[i_pos] + neg = y_pred[i_neg] + + dp = dprime_from_samp(pos, neg, bypass_nchk=True, **kwargs) + return dp + + +def dprime_from_samp(pos, neg, maxv=None, minv=None, safedp=True, bypass_nchk=False): + """Computes the d-prime sensitivity index from positive and negative samples. + + Parameters + ---------- + pos: array-like + Positive sample values (e.g., raw projection values of the positive classifier). + + neg: array-like + Negative sample values. + + maxv: float, optional + Upper bound to which the d-prime value is clipped. If None (default), + no upper bound is applied. + + minv: float, optional + Lower bound to which the d-prime value is clipped. If None (default), no lower bound is applied. + + safedp: bool, optional + If True (default), this function will return None if the resulting d-prime + value becomes non-finite.
+ + bypass_nchk: bool, optional + If False (default), assert that both the positive and negative samples have + more than one element (needed for the variance estimation); if True, skip it. - assert pos.sum() > 1, 'Not enough positives to estimate the variance' - assert neg.sum() > 1, 'Not enough negatives to estimate the variance' + Returns + ------- + dp: float or None + d-prime, or None if d-prime is non-finite (undefined) and the raw value + was not requested via ``safedp=False``. + + References + ---------- + http://en.wikipedia.org/wiki/D' + """ + + pos = np.array(pos) + neg = np.array(neg) - pos_mean = y_pred[pos].mean() - neg_mean = y_pred[neg].mean() - pos_var = y_pred[pos].var(ddof=1) - neg_var = y_pred[neg].var(ddof=1) + if not bypass_nchk: + assert pos.size > 1, 'Not enough positive samples to estimate the variance' + assert neg.size > 1, 'Not enough negative samples to estimate the variance' + + pos_mean = pos.mean() + neg_mean = neg.mean() + pos_var = pos.var(ddof=1) + neg_var = neg.var(ddof=1) num = pos_mean - neg_mean div = np.sqrt((pos_var + neg_var) / 2.) - if div == 0: + + # from Dan's suggestion about clipping d' values... + if maxv is None: + maxv = np.inf + if minv is None: + minv = -np.inf + + dp = np.clip(num / div, minv, maxv) + + if safedp and not np.isfinite(dp): dp = None - else: - dp = num / div return dp