diff --git a/src/pqm/pqm.py b/src/pqm/pqm.py index ed1601f..a1d9099 100644 --- a/src/pqm/pqm.py +++ b/src/pqm/pqm.py @@ -7,9 +7,16 @@ __all__ = ("pqm_chi2", "pqm_pvalue") -def _pqm_test(x_samples: np.ndarray, y_samples: np.ndarray, num_refs: int, whiten: bool): +def _pqm_test( + x_samples: np.ndarray, + y_samples: np.ndarray, + num_refs: int, + whiten: bool, + x_frac: Optional[float] = None, +): """ - Helper function to perform the PQM test and return the results from chi2_contingency. + Helper function to perform the PQM test and return the results from + chi2_contingency. Parameters ---------- @@ -20,7 +27,14 @@ def _pqm_test(x_samples: np.ndarray, y_samples: np.ndarray, num_refs: int, white num_refs : int Number of reference samples to use. whiten : bool - If True, whiten the samples by subtracting the mean and dividing by the standard deviation. + If True, whiten the samples by subtracting the mean and dividing by the + standard deviation. + x_frac : Optional[float] + Fraction of the reference samples to draw from x_samples. ``x_frac = 1`` will + use only x_samples as reference samples, ``x_frac = 0`` will use only + y_samples as reference samples. Ideally, ``x_frac = len(x_samples) / + (len(x_samples) + len(y_samples))`` which is what is done for x_frac = + None (default). 
Returns ------- @@ -41,12 +55,34 @@ def _pqm_test(x_samples: np.ndarray, y_samples: np.ndarray, num_refs: int, white y_samples = (y_samples - mean) / std x_samples = (x_samples - mean) / std - refs = np.random.choice(len(y_samples), num_refs, replace=False) - N = np.arange(len(y_samples)) - N[refs] = -1 - N = N[N >= 0] - refs, y_samples = y_samples[refs], y_samples[N] - + # Determine fraction of reference samples to draw from x_samples + if x_frac is None: + x_frac = len(x_samples) / (len(x_samples) + len(y_samples)) + + # Collect reference samples from x_samples + if x_frac > 0: + xrefs = np.random.choice(len(x_samples), int(x_frac * num_refs), replace=False) + N = np.arange(len(x_samples)) + N[xrefs] = -1 + N = N[N >= 0] + xrefs, x_samples = x_samples[xrefs], x_samples[N] + else: + xrefs = np.zeros((0,) + x_samples.shape[1:]) + + # Collect reference samples from y_samples + if x_frac < 1: + yrefs = np.random.choice(len(y_samples), int((1.0 - x_frac) * num_refs), replace=False) + N = np.arange(len(y_samples)) + N[yrefs] = -1 + N = N[N >= 0] + yrefs, y_samples = y_samples[yrefs], y_samples[N] + else: + yrefs = np.zeros((0,) + y_samples.shape[1:]) + + # Join the full set of reference samples + refs = np.concatenate([xrefs, yrefs], axis=0) + + # Build KDtree to measure distances tree = KDTree(refs) idx = tree.query(x_samples, k=1, workers=-1)[1] @@ -68,34 +104,50 @@ def pqm_pvalue( num_refs: int = 100, re_tessellation: Optional[int] = None, whiten: bool = False, + x_frac: Optional[float] = None, ): """ - Perform the PQM test of the null hypothesis that `x_samples` and `y_samples` are drawn form the same distribution. + Perform the PQM test of the null hypothesis that `x_samples` and `y_samples` + are drawn from the same distribution. Parameters ---------- x_samples : np.ndarray - Samples from the first distribution, test samples. Must have shape (N, *D) N is the number of x samples, and D is the dimensionality of the samples. 
+ Samples from the first distribution, test samples. Must have shape (N, + *D), where N is the number of x samples, and D is the dimensionality of the + samples. y_samples : np.ndarray - Samples from the second distribution, reference samples. Must have shape (M, *D) M is the number of y samples, and D is the dimensionality of the samples. + Samples from the second distribution, reference samples. Must have shape + (M, *D), where M is the number of y samples, and D is the dimensionality of the + samples. num_refs : int - Number of reference samples to use. Note that these will be drawn from y_samples, and then removed from the y_samples array. + Number of reference samples to use. Note that these will be drawn from x_samples + and y_samples (split according to x_frac), and then removed from those arrays. re_tessellation : Optional[int] - Number of times pqm_pvalue is called, re tesselating the space. No re_tessellation if None (default). + Number of times pqm_pvalue is called, re-tessellating the space. No + re_tessellation if None (default). whiten : bool - If True, whiten the samples by subtracting the mean and dividing by the standard deviation. + If True, whiten the samples by subtracting the mean and dividing by the + standard deviation. + x_frac : Optional[float] + Fraction of the reference samples to draw from x_samples. ``x_frac = 1`` will + use only x_samples as reference samples, ``x_frac = 0`` will use only + y_samples as reference samples. Ideally, ``x_frac = len(x_samples) / + (len(x_samples) + len(y_samples))`` which is what is done for x_frac = + None (default). Returns ------- float or list - pvalue(s). Null hypothesis that both samples are drawn from the same distribution. + pvalue(s). Null hypothesis that both samples are drawn from the same + distribution. 
""" if re_tessellation is not None: return [ - pqm_pvalue(x_samples, y_samples, num_refs=num_refs, whiten=whiten) + pqm_pvalue(x_samples, y_samples, num_refs=num_refs, whiten=whiten, x_frac=x_frac) for _ in range(re_tessellation) ] - _, pvalue, _, _ = _pqm_test(x_samples, y_samples, num_refs, whiten) + _, pvalue, _, _ = _pqm_test(x_samples, y_samples, num_refs, whiten, x_frac) return pvalue @@ -105,22 +157,37 @@ def pqm_chi2( num_refs: int = 100, re_tessellation: Optional[int] = None, whiten: bool = False, + x_frac: Optional[float] = None, ): """ - Perform the PQM test of the null hypothesis that `x_samples` and `y_samples` are drawn form the same distribution. + Perform the PQM test of the null hypothesis that `x_samples` and `y_samples` + are drawn from the same distribution. Parameters ---------- x_samples : np.ndarray - Samples from the first distribution, test samples. Must have shape (N, *D) N is the number of x samples, and D is the dimensionality of the samples. + Samples from the first distribution, test samples. Must have shape (N, + *D), where N is the number of x samples, and D is the dimensionality of the + samples. y_samples : np.ndarray - Samples from the second distribution, reference samples. Must have shape (M, *D) M is the number of y samples, and D is the dimensionality of the samples. + Samples from the second distribution, reference samples. Must have shape + (M, *D), where M is the number of y samples, and D is the dimensionality of the + samples. num_refs : int - Number of reference samples to use. Note that these will be drawn from y_samples, and then removed from the y_samples array. + Number of reference samples to use. Note that these will be drawn from x_samples + and y_samples (split according to x_frac), and then removed from those arrays. re_tessellation : Optional[int] - Number of times pqm_chi2 is called, re tesselating the space. No re_tessellation if None (default). + Number of times pqm_chi2 is called, re-tessellating the space. No + re_tessellation if None (default). 
whiten : bool - If True, whiten the samples by subtracting the mean and dividing by the standard deviation. + If True, whiten the samples by subtracting the mean and dividing by the + standard deviation. + x_frac : Optional[float] + Fraction of the reference samples to draw from x_samples. ``x_frac = 1`` will + use only x_samples as reference samples, ``x_frac = 0`` will use only + y_samples as reference samples. Ideally, ``x_frac = len(x_samples) / + (len(x_samples) + len(y_samples))`` which is what is done for x_frac = + None (default). Returns ------- @@ -129,10 +196,10 @@ def pqm_chi2( """ if re_tessellation is not None: return [ - pqm_chi2(x_samples, y_samples, num_refs=num_refs, whiten=whiten) + pqm_chi2(x_samples, y_samples, num_refs=num_refs, whiten=whiten, x_frac=x_frac) for _ in range(re_tessellation) ] - chi2_stat, _, dof, _ = _pqm_test(x_samples, y_samples, num_refs, whiten) + chi2_stat, _, dof, _ = _pqm_test(x_samples, y_samples, num_refs, whiten, x_frac) if dof != num_refs - 1: # Rescale chi2 to new value which has the same cumulative probability if chi2_stat / dof < 10: @@ -141,4 +208,4 @@ def pqm_chi2( else: chi2_stat = chi2_stat * (num_refs - 1) / dof dof = num_refs - 1 - return chi2_stat \ No newline at end of file + return chi2_stat