Commit
Showing 1 changed file with 89 additions and 67 deletions.
@@ -6,7 +6,7 @@
 #
 # License: BSD

-__all__ = ['dprime', 'dprime_from_confusion_ova']
+__all__ = ['dprime', 'dprime_from_samp', 'dprime_from_confusion']

 import numpy as np
 from scipy.stats import norm
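Downstream code that imported the old public name needs updating to match the new `__all__`; a minimal sketch, assuming the module is importable as `dprime` (the actual package path is not shown on this page):

>>> from dprime import dprime, dprime_from_samp, dprime_from_confusion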
@@ -33,9 +33,8 @@ def dprime(y_pred, y_true, **kwargs):
     Returns
     -------
-    dp: float or None
-        d-prime, None if d-prime is undefined and raw d-prime value (``safedp=False``)
-        is not requested (default).
+    dp: float
+        d-prime

     References
     ----------
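With the `None`-returning behavior removed, `dprime` now always returns a float. A hedged usage sketch (how `y_true` is split into positive and negative groups is not shown in this hunk, so the label convention below is illustrative):

>>> import numpy as np
>>> y_true = np.array([1, 1, 1, -1, -1, -1])          # assumed: positives labeled > 0
>>> y_pred = np.array([0.9, 0.8, 0.6, 0.4, 0.2, 0.1])
>>> dp = dprime(y_pred, y_true)                       # plain float, no None checks needed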
@@ -60,11 +59,11 @@ def dprime(y_pred, y_true, **kwargs):
     pos = y_pred[i_pos]
     neg = y_pred[i_neg]

-    dp = dprime_from_samp(pos, neg, bypass_nchk=True, **kwargs)
+    dp = dprime_from_samp(pos, neg, **kwargs)
     return dp


-def dprime_from_samp(pos, neg, maxv=None, minv=None, safedp=True, bypass_nchk=False):
+def dprime_from_samp(pos, neg, max_value=np.inf, min_value=-np.inf):
     """Computes the d-prime sensitivity index from positive and negative samples.

     Parameters
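A hedged sketch of the simplified signature (the sample values are illustrative):

>>> pos = [0.9, 0.8, 0.7, 0.6]                  # positive-class scores
>>> neg = [0.4, 0.3, 0.2, 0.1]                  # negative-class scores
>>> dprime_from_samp(pos, neg)                  # unclipped d'
>>> dprime_from_samp(pos, neg, max_value=5.0)   # clip unusually large values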
@@ -75,26 +74,16 @@ def dprime_from_samp(pos, neg, maxv=None, minv=None, safedp=True, bypass_nchk=Fa
     neg: array-like
         Negative sample values.
-    maxv: float, optional
-        Maximum possible d-prime value. If None (default), there's no limit on
-        the maximum value.
+    max_value: float, optional
+        Maximum possible d-prime value. Default is ``np.inf``.
-    minv: float, optional
-        Minimum possible d-prime value. If None (default), there's no limit.
-    safedp: bool, optional
-        If True (default), this function will return None if the resulting d-prime
-        value becomes non-finite.
-    bypass_nchk: bool, optional
-        If False (default), do not bypass the test to ensure that enough positive
-        and negatives samples are there for the variance estimation.
+    min_value: float, optional
+        Minimum possible d-prime value. Default is ``-np.inf``.

     Returns
     -------
-    dp: float or None
-        d-prime, None if d-prime is undefined and raw d-prime value (``safedp=False``)
-        is not requested (default).
+    dp: float
+        d-prime

     References
     ----------
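For reference, the quantity these bounds clip is the standard equal-variance d'. In the code's terms, `div` is the pooled standard deviation visible at the top of the final hunk, and `num` is implied to be the difference of the sample means computed just before it:

$$ d' = \frac{\mu_{\mathrm{pos}} - \mu_{\mathrm{neg}}}{\sqrt{(\sigma_{\mathrm{pos}}^{2} + \sigma_{\mathrm{neg}}^{2}) / 2}} $$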
@@ -104,9 +93,10 @@ def dprime_from_samp(pos, neg, maxv=None, minv=None, safedp=True, bypass_nchk=Fa
     pos = np.array(pos)
     neg = np.array(neg)

-    if not bypass_nchk:
-        assert pos.size > 1, 'Not enough positive samples to estimate the variance'
-        assert neg.size > 1, 'Not enough negative samples to estimate the variance'
+    if pos.size <= 1:
+        raise ValueError('Not enough positive samples to estimate the variance')
+    if neg.size <= 1:
+        raise ValueError('Not enough negative samples to estimate the variance')

     pos_mean = pos.mean()
     neg_mean = neg.mean()
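Since the size checks now raise ``ValueError`` instead of ``AssertionError`` (and so survive ``python -O``, which strips asserts), callers can handle the degenerate case explicitly; a small sketch:

>>> try:
...     dprime_from_samp([1.0], [0.2, 0.1])     # only one positive sample
... except ValueError as e:
...     print(e)
Not enough positive samples to estimate the variance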
@@ -117,82 +107,114 @@ def dprime_from_samp(pos, neg, maxv=None, minv=None, safedp=True, bypass_nchk=Fa
     div = np.sqrt((pos_var + neg_var) / 2.)

-    # from Dan's suggestion about clipping d' values...
-    if maxv is None:
-        maxv = np.inf
-    if minv is None:
-        minv = -np.inf
-
-    dp = np.clip(num / div, minv, maxv)
-
-    if safedp and not np.isfinite(dp):
-        dp = None
+    dp = np.clip(num / div, min_value, max_value)

     return dp


-def dprime_from_confusion_ova(M, fudge_mode=DEFAULT_FUDGE_MODE, \
-        fudge_fac=DEFAULT_FUDGE_FACTOR):
+def dprime_from_confusion(M, collation=None, fudge_mode=DEFAULT_FUDGE_MODE, \
+        fudge_factor=DEFAULT_FUDGE_FACTOR, max_value=np.inf, min_value=-np.inf):
     """Computes the one-vs-all d-prime sensitivity index of the confusion matrix.

     This function is mostly for when there is no access to internal representation
     and/or decision making (like human data).

     Parameters
     ----------
-    M: array, shape = [n_classes (true), n_classes (pred)]
+    M: array-like, shape = [n_classes (true), n_classes (pred)]
         Confusion matrix, where the element M_{rc} means the number of
         times when the classifier guesses that a test sample in the r-th class
         belongs to the c-th class.

-    fudge_fac: float, optional
+    collation: None (default) or array-like with shape = [n_grouping, n_classes]
+        Defines how to group entries in `M` to compute TPR and FPR.
+        Entries should be {+1, 0, -1}. A row defines one instance of grouping,
+        where +1, -1, and 0 designate the corresponding class as a
+        positive, negative, and ignored class, respectively. For example,
+        the following `collation` defines a 3-way one vs. rest grouping
+        (given that `M` is a 3x3 matrix):
+            [[+1, -1, -1],
+             [-1, +1, -1],
+             [-1, -1, +1]]
+        If `None` (default), one vs. rest grouping is assumed.
+
+    fudge_factor: float, optional
         A small factor to avoid non-finite numbers when TPR or FPR becomes 0 or 1.
         Default is 0.5.

     fudge_mode: str, optional
-        Determins how to apply the fudge factor
-        'always': always apply the fudge factor
-        'correction': apply only when needed
+        Determines how to apply the fudge factor. Can be one of:
+        'correction': apply only when needed (default)
+        'always': always apply the fudge factor
+        'none': no fudging --- equivalent to ``fudge_factor=0``
+
+    max_value: float, optional
+        Maximum possible d-prime value. Default is ``np.inf``.
+
+    min_value: float, optional
+        Minimum possible d-prime value. Default is ``-np.inf``.

     Returns
     -------
-    dp: array, shape = [n_classes]
-        Array of d-primes, each element corresponding to each class
+    dp: array, shape = [n_grouping]
+        Array of d-primes, where each element corresponds to each grouping
+        defined by `collation`.

     References
     ----------
     http://en.wikipedia.org/wiki/D'
     http://en.wikipedia.org/wiki/Confusion_matrix
     """

     # M: confusion matrix, row means true classes, col means predicted classes
     M = np.array(M)
     assert M.ndim == 2
     assert M.shape[0] == M.shape[1]
+    n_classes = M.shape[0]

-    P = np.sum(M, axis=1)    # number of positives, for each class
-    N = np.sum(P) - P
-
-    TP = np.diag(M)
-    FP = np.sum(M, axis=0) - TP
-
-    if fudge_mode == 'always':    # always apply fudge factor
-        TPR = (TP.astype('float') + fudge_fac) / (P + 2.*fudge_fac)
-        FPR = (FP.astype('float') + fudge_fac) / (N + 2.*fudge_fac)
+    if collation is None:
+        # make it one vs. rest
+        collation = -np.ones((n_classes, n_classes), dtype='int8')
+        collation += 2 * np.eye(n_classes, dtype='int8')
+    else:
+        collation = np.array(collation, dtype='int8')
+        assert collation.ndim == 2
+        assert collation.shape[1] == n_classes
+
+    # P0: number of positives, for each class
+    # P:  number of positives, for each grouping
+    # N:  number of negatives, for each grouping
+    # TP: number of true positives, for each grouping
+    # FP: number of false positives, for each grouping
+    P0 = np.sum(M, axis=1)
+    P = np.array([np.sum(P0[coll == +1]) for coll in collation], dtype='float64')
+    N = np.array([np.sum(P0[coll == -1]) for coll in collation], dtype='float64')
+    TP = np.array([np.sum(M[coll == +1][:, coll == +1]) for coll in collation], dtype='float64')
+    FP = np.array([np.sum(M[coll == -1][:, coll == +1]) for coll in collation], dtype='float64')

+    # -- application of fudge factor
+    if fudge_mode == 'none':          # no fudging
+        pass
+
+    elif fudge_mode == 'always':      # always apply fudge factor
+        TP += fudge_factor
+        FP += fudge_factor
+        P += 2.*fudge_factor
+        N += 2.*fudge_factor

     elif fudge_mode == 'correction':  # apply fudge factor only when needed
-        TP = TP.astype('float')
-        FP = FP.astype('float')
-
-        TP[TP == P] = P[TP == P] - fudge_fac    # 100% correct
-        TP[TP == 0] = fudge_fac                 # 0% correct
-        FP[FP == N] = N[FP == N] - fudge_fac    # always FAR
-        FP[FP == 0] = fudge_fac                 # no false alarm
-
-        TPR = TP / P
-        FPR = FP / N
+        TP[TP == P] = P[TP == P] - fudge_factor    # 100% correct
+        TP[TP == 0] = fudge_factor                 # 0% correct
+        FP[FP == N] = N[FP == N] - fudge_factor    # always FAR
+        FP[FP == 0] = fudge_factor                 # no false alarm

     else:
-        assert False, 'Not implemented'
+        raise ValueError('Invalid fudge_mode')

-    dp = norm.ppf(TPR) - norm.ppf(FPR)
-    # if there's only two dp's then, it's must be "A" vs. "~A" task.  If so, just give one value
-    if len(dp) == 2:
-        dp = np.array([dp[0]])
+    # -- done. compute the d'
+    TPR = TP / P
+    FPR = FP / N
+    dp = np.clip(norm.ppf(TPR) - norm.ppf(FPR), min_value, max_value)

     return dp
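A hedged usage sketch of the reworked API (the 3x3 confusion matrix and the custom grouping below are illustrative). Each returned element is norm.ppf(TPR) - norm.ppf(FPR) for one grouping, clipped to [min_value, max_value]:

>>> import numpy as np
>>> M = np.array([[10,  2,  1],
...               [ 3,  9,  2],
...               [ 1,  2, 11]])
>>> dp_ovr = dprime_from_confusion(M)        # default: 3-way one vs. rest, 3 d' values
>>> coll = [[+1, -1, 0]]                     # class 0 vs. class 1, ignoring class 2
>>> dp_pair = dprime_from_confusion(M, collation=coll)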
hahong (Author): should drop "one-vs-all" in the DOC. Will be reflected in the next commit.