Interpretable detector and deep kernels detector (#306)
* Refactor pt GaussianRBF and introduce DeepKernel

* NME functionality

Docstrings etc to follow.

* Fix tensorflow implementation

* Update GaussianRBF and DeepKernel to init sigma correctly

* Typo

* Minor refactor to nme implementations

* Minor correction

* Fix kernels to remain compatible with other uses

* Upload temp nme example scripts

* Add todo for kernel initialisation

* Allow model to continue training across test batches

* Remove nme stuff

* Minor bug fix

* Allow ClassifierDrift to return model

* Minor bug fix

* Update trainer and ClfDrift to print ma and allow reg term

* Implementation of a spot-the-diff drift detector

* Remove temp NME scripts

* Remove all references of compile_kwargs

* Remove exponentiation of clf coefficients

* Allow initialisation of diffs to be specified

* Batch computation of kernel matrices

* Fix batchwise kernel mat computation

* Further fixes to batch comp of kernel matrices

* Implementation of learnt kernel detector

* Update ClfDrift.predict() docstring

* Update STDDetector to handle lists. Also docstrings

* retrain_from_scratch arg bug fix

* Change clf example nb to gen learned nb

* Add reg_loss_fn to trainer docstring

* Some type hints

* Remove dead line

* Sub n=len(x) for readability

* Transfer x_cur, x_ref to device after preprocess_fn

* Simmilarity -> similarity in about 100 places

* Update type hint to Union[np.ndarray, list]

* Make clear in STDDrift docstring diff is featurewise

* Remove the #TODO about GRBF init

* Learnt -> learned everywhere

* Minor bug fix

* Indentation bug fix

* Update STDDetector to handle List[Any]

* Spot-the-diff drift example notebook

* Minor changes to clf/lk example notebook

* Bug fix around dead kwargs

* Add kwargs to STDDetector meta dict

* Bug fix around meta

* Tests for STDDetector and LKDetector

* Type hints

* Init optimizer inside trainer

* STDDiff notebook typos

* Docs method notebook for learned kernel detector

* Docs methods notebook for spot-the-diff detector

* Some renaming and resourcing

* Credit Jitkrittum

* Correct typo

* Nblink for stddrift example notebook

* For GaussianRBF cast second input to dtype of first

* Kernel tests

* Deep kernel docstrings

* Fix batched kernel matrix and test it

* Update how batch_compute_kernel_matrix is used

* Sort STD notebooks and Jitkrittum credit

* Fix typos

* Catch some edge cases when using multiple folds

* Doc stuff

* Minor changes to docs

* n_folds bug fix

* Coffs -> coeffs

* Satisfy mypy

* Satisfy flake8

* update to .gitignore

* Minor typos

* Ignore future annotations!

* More future annotations stuff

* Another go at future annotations

* Remove future annotations
ojcobb authored Aug 16, 2021
1 parent 0355f9b commit e965165
Showing 41 changed files with 3,839 additions and 181 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -115,3 +115,6 @@ venv.bak/
# mypy
.mypy_cache/

# temp scripts for debugging
examples/temp_*

8 changes: 8 additions & 0 deletions README.md
@@ -114,10 +114,12 @@ The following tables show the advised use cases for each algorithm. The column *
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Kolmogorov-Smirnov | ✔ | ✔ |  | ✔ | ✔ |  | ✔ |
| Maximum Mean Discrepancy | ✔ | ✔ |  | ✔ | ✔ | ✔ |  |
| Learned Kernel MMD | ✔ | ✔ |  | ✔ | ✔ |  |  |
| Least-Squares Density Difference | ✔ | ✔ |  | ✔ | ✔ | ✔ |  |
| Chi-Squared | ✔ |  |  |  | ✔ |  | ✔ |
| Mixed-type tabular data | ✔ |  |  |  | ✔ |  | ✔ |
| Classifier | ✔ | ✔ | ✔ | ✔ | ✔ |  |  |
| Spot-the-diff | ✔ | ✔ | ✔ | ✔ | ✔ |  | ✔ |
| Classifier Uncertainty | ✔ | ✔ | ✔ | ✔ | ✔ |  |  |
| Regressor Uncertainty | ✔ | ✔ | ✔ | ✔ | ✔ |  |  |

@@ -210,6 +212,9 @@ Check the example notebooks (e.g. [CIFAR10](https://docs.seldon.io/projects/alib
- [Maximum Mean Discrepancy](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/mmddrift.html) ([Gretton et al, 2012](http://jmlr.csail.mit.edu/papers/v13/gretton12a.html))
- Example: [CIFAR10](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mmd_cifar10.html), [molecular graphs](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mol.html), [movie reviews](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_text_imdb.html), [Amazon reviews](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_text_amazon.html)

- [Learned Kernel MMD](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/learnedkernel.html) ([Liu et al, 2020](https://arxiv.org/abs/2002.09116))
- Example: [CIFAR10](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_clf_cifar10.html)

- [Chi-Squared](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/chisquaredrift.html)
- Example: [Income Prediction](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_chi2ks_adult.html)

@@ -219,6 +224,9 @@ Check the example notebooks (e.g. [CIFAR10](https://docs.seldon.io/projects/alib
- [Classifier](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/classifierdrift.html) ([Lopez-Paz and Oquab, 2017](https://openreview.net/forum?id=SJkXfE5xx))
- Example: [CIFAR10](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_clf_cifar10.html), [Amazon reviews](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_text_amazon.html)

- [Spot-the-diff](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/spotthediffdrift.html) (adaptation of [Jitkrittum et al, 2016](https://arxiv.org/abs/1605.06796))
- Example: [MNIST and Wine quality](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/spot_the_diff_mnist_win.html)

- [Classifier and Regressor Uncertainty](https://docs.seldon.io/projects/alibi-detect/en/latest/methods/modeluncdrift.html)
- Example: [CIFAR10 and Wine](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_model_unc_cifar10_wine.html), [molecular graphs](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mol.html)

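For orientation, here is a minimal sketch of how the two detectors introduced by this commit could be used. The detector names come from the `alibi_detect/cd/__init__.py` diff below; the `DeepKernel` import path, its constructor, and the `backend` keyword are assumptions, and the data and projection network are placeholders.

```python
import numpy as np
import torch.nn as nn
from alibi_detect.cd import LearnedKernelDrift, SpotTheDiffDrift
from alibi_detect.utils.pytorch import DeepKernel  # import path assumed

x_ref = np.random.randn(500, 32).astype(np.float32)   # placeholder reference data
x_test = np.random.randn(100, 32).astype(np.float32)  # placeholder test batch

# Learned Kernel MMD: train a deep kernel to maximise the power of an MMD
# test, then run a permutation test on the held-out split.
proj = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8))
kernel = DeepKernel(proj)  # constructor details assumed
lk_cd = LearnedKernelDrift(x_ref, kernel, backend='pytorch', p_val=.05, train_size=.75)
print(lk_cd.predict(x_test)['data']['is_drift'])

# Spot-the-diff: a classifier-based test whose learned coefficients are
# interpretable as feature-wise differences between reference and test data.
std_cd = SpotTheDiffDrift(x_ref, backend='pytorch', p_val=.05)
print(std_cd.predict(x_test)['data']['is_drift'])
```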
6 changes: 5 additions & 1 deletion alibi_detect/cd/__init__.py
@@ -1,8 +1,10 @@
from .chisquare import ChiSquareDrift
from .classifier import ClassifierDrift
from .ks import KSDrift
from .learned_kernel import LearnedKernelDrift
from .lsdd import LSDDDrift
from .lsdd_online import LSDDDriftOnline
from .spot_the_diff import SpotTheDiffDrift
from .mmd import MMDDrift
from .mmd_online import MMDDriftOnline
from .model_uncertainty import ClassifierUncertaintyDrift, RegressorUncertaintyDrift
@@ -12,11 +14,13 @@
"ChiSquareDrift",
"ClassifierDrift",
"KSDrift",
"LearnedKernelDrift",
"LSDDDrift",
"LSDDDriftOnline",
"MMDDrift",
"MMDDriftOnline",
"TabularDrift",
"ClassifierUncertaintyDrift",
"RegressorUncertaintyDrift"
"RegressorUncertaintyDrift",
"SpotTheDiffDrift"
]
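With the updated exports, both new detectors are importable from the subpackage root:

```python
from alibi_detect.cd import LearnedKernelDrift, SpotTheDiffDrift
```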
194 changes: 190 additions & 4 deletions alibi_detect/cd/base.py
@@ -30,6 +30,7 @@ def __init__(
binarize_preds: bool = False,
train_size: Optional[float] = .75,
n_folds: Optional[int] = None,
retrain_from_scratch: bool = True,
seed: int = 0,
data_type: Optional[str] = None
) -> None:
@@ -63,6 +64,9 @@ def __init__(
on all the out-of-fold predictions. This allows leveraging all the reference and test data
for drift detection at the expense of longer computation. If both `train_size` and `n_folds`
are specified, `n_folds` is prioritized.
retrain_from_scratch
Whether the classifier should be retrained from scratch for each set of test data or whether
it should instead continue training from where it left off on the previous set.
seed
Optional random seed for fold selection.
data_type
@@ -79,6 +83,9 @@ def __init__(
if preds_type not in ['probs', 'logits']:
raise ValueError("'preds_type' should be 'probs' or 'logits'")

if n_folds is not None and n_folds > 1 and not retrain_from_scratch:
raise ValueError("If using multiple folds the model must be retrained from scratch for each fold.")

# optionally already preprocess reference data
self.p_val = p_val
if preprocess_x_ref and isinstance(preprocess_fn, Callable): # type: ignore
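A sketch of the new `n_folds`/`retrain_from_scratch` guard above (the data, model, and `backend` keyword are placeholders/assumptions): combining multiple folds with warm-started training now fails fast.

```python
import numpy as np
import torch.nn as nn
from alibi_detect.cd import ClassifierDrift

x_ref = np.random.randn(100, 8).astype(np.float32)  # placeholder reference data
model = nn.Sequential(nn.Linear(8, 2))              # placeholder discriminator
try:
    ClassifierDrift(x_ref, model, backend='pytorch', n_folds=5, retrain_from_scratch=False)
except ValueError as e:
    print(e)  # If using multiple folds the model must be retrained from scratch for each fold.
```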
@@ -98,6 +105,7 @@ def __init__(
self.skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
else:
self.train_size, self.skf = train_size, None
self.retrain_from_scratch = retrain_from_scratch

# set metadata
self.meta['detector_type'] = 'offline'
@@ -203,8 +211,8 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, n
pass

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
return_distance: bool = True, return_probs: bool = True) \
-> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
return_distance: bool = True, return_probs: bool = True, return_model: bool = True) \
-> Dict[str, Dict[str, Union[str, int, float, Callable]]]:
"""
Predict whether a batch of data has drifted from the reference data.
@@ -220,14 +228,16 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
return_probs
Whether to return the instance level classifier probabilities for the reference and test data
(0=reference data, 1=test data).
return_model
Whether to return the updated model trained to discriminate reference and test instances.
Returns
-------
Dictionary containing 'meta' and 'data' dictionaries.
'meta' has the model's metadata.
'data' contains the drift prediction and optionally the performance of the classifier
'data' contains the drift prediction and optionally the p-value, performance of the classifier
relative to its expectation under the no-change null, the out-of-fold classifier model
prediction probabilities on the reference and test data.
prediction probabilities on the reference and test data, and the trained model.
"""
# compute drift scores
p_val, dist, probs_ref, probs_test = self.score(x)
@@ -252,6 +262,182 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
if return_probs:
cd['data']['probs_ref'] = probs_ref
cd['data']['probs_test'] = probs_test
if return_model:
cd['data']['model'] = self.model # type: ignore
return cd
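
Given a constructed `ClassifierDrift` detector `cd` and a test batch `x` (both placeholders), the new `return_model` flag exposes the trained discriminator alongside the existing outputs; a sketch:

```python
preds = cd.predict(x, return_p_val=True, return_distance=True,
                   return_probs=True, return_model=True)
print(preds['data']['is_drift'], preds['data']['p_val'])
clf = preds['data']['model']  # model trained to discriminate reference vs. test data
```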


class BaseLearnedKernelDrift(BaseDetector):
def __init__(
self,
x_ref: Union[np.ndarray, list],
p_val: float = .05,
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
n_permutations: int = 100,
train_size: Optional[float] = .75,
retrain_from_scratch: bool = True,
data_type: Optional[str] = None
) -> None:
"""
Base class for the learned kernel-based drift detector.
Parameters
----------
x_ref
Data used as reference distribution.
p_val
p-value used for the significance of the test.
preprocess_x_ref
Whether to already preprocess and store the reference data.
update_x_ref
Reference data can optionally be updated to the last n instances seen by the detector
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
n_permutations
The number of permutations to use in the permutation test once the MMD has been computed.
train_size
Optional fraction (float between 0 and 1) of the dataset used to train the kernel.
The drift is detected on `1 - train_size`. Cannot be used in combination with `n_folds`.
retrain_from_scratch
Whether the kernel should be retrained from scratch for each set of test data or whether
it should instead continue training from where it left off on the previous set.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
"""
super().__init__()

if p_val is None:
logger.warning('No p-value set for the drift threshold. Need to set it to detect data drift.')

# optionally already preprocess reference data
self.p_val = p_val
if preprocess_x_ref and isinstance(preprocess_fn, Callable): # type: ignore
self.x_ref = preprocess_fn(x_ref)
else:
self.x_ref = x_ref
self.preprocess_x_ref = preprocess_x_ref
self.update_x_ref = update_x_ref
self.preprocess_fn = preprocess_fn
self.n = len(x_ref) # type: ignore

self.n_permutations = n_permutations
self.train_size = train_size
self.retrain_from_scratch = retrain_from_scratch

# set metadata
self.meta['detector_type'] = 'offline'
self.meta['data_type'] = data_type

def preprocess(self, x: Union[np.ndarray, list]) -> Tuple[Union[np.ndarray, list], Union[np.ndarray, list]]:
"""
Data preprocessing before computing the drift scores.
Parameters
----------
x
Batch of instances.
Returns
-------
Preprocessed reference data and new instances.
"""
if isinstance(self.preprocess_fn, Callable): # type: ignore
x = self.preprocess_fn(x)
x_ref = self.x_ref if self.preprocess_x_ref else self.preprocess_fn(self.x_ref)
return x_ref, x
else:
return self.x_ref, x

def get_splits(self, x_ref: Union[np.ndarray, list], x: Union[np.ndarray, list]) \
-> Tuple[Tuple[Union[np.ndarray, list], Union[np.ndarray, list]],
Tuple[Union[np.ndarray, list], Union[np.ndarray, list]]]:
"""
Split reference and test data into two splits -- one on which to learn test locations
and parameters and one on which to perform the tests.
Parameters
----------
x_ref
Data used as reference distribution.
x
Batch of instances.
Returns
-------
Tuple containing split train data and tuple containing split test data
"""

n_ref, n_cur = len(x_ref), len(x)
perm_ref, perm_cur = np.random.permutation(n_ref), np.random.permutation(n_cur)
idx_ref_tr, idx_ref_te = perm_ref[:int(n_ref*self.train_size)], perm_ref[int(n_ref*self.train_size):]
idx_cur_tr, idx_cur_te = perm_cur[:int(n_cur*self.train_size)], perm_cur[int(n_cur*self.train_size):]

if isinstance(x_ref, np.ndarray):
x_ref_tr, x_ref_te = x_ref[idx_ref_tr], x_ref[idx_ref_te]
x_cur_tr, x_cur_te = x[idx_cur_tr], x[idx_cur_te]
elif isinstance(x, list):
x_ref_tr, x_ref_te = [x_ref[_] for _ in idx_ref_tr], [x_ref[_] for _ in idx_ref_te]
x_cur_tr, x_cur_te = [x[_] for _ in idx_cur_tr], [x[_] for _ in idx_cur_te]
else:
raise TypeError(f'x needs to be of type np.ndarray or list and not {type(x)}.')

return (x_ref_tr, x_cur_tr), (x_ref_te, x_cur_te)
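
A quick numeric check of the split arithmetic above (a standalone sketch, not part of the commit): with `train_size=.75`, 100 reference instances yield a 75/25 train/test split.

```python
import numpy as np

train_size = .75
n_ref = 100  # placeholder sample size
perm_ref = np.random.permutation(n_ref)
idx_ref_tr = perm_ref[:int(n_ref * train_size)]  # 75 training indices
idx_ref_te = perm_ref[int(n_ref * train_size):]  # 25 test indices
assert len(idx_ref_tr) == 75 and len(idx_ref_te) == 25
```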

@abstractmethod
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
pass

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
return_distance: bool = True, return_kernel: bool = True) \
-> Dict[Dict[str, str], Dict[str, Union[int, float, Callable]]]:
"""
Predict whether a batch of data has drifted from the reference data.
Parameters
----------
x
Batch of instances.
return_p_val
Whether to return the p-value of the permutation test.
return_distance
Whether to return the MMD metric between the new batch and reference data.
return_kernel
Whether to return the updated kernel trained to discriminate reference and test instances.
Returns
-------
Dictionary containing 'meta' and 'data' dictionaries.
'meta' has the detector's metadata.
'data' contains the drift prediction and optionally the p-value, threshold, MMD metric and
trained kernel.
"""
# compute drift scores
p_val, dist, dist_permutations = self.score(x)
drift_pred = int(p_val < self.p_val)

# compute distance threshold
idx_threshold = int(self.p_val * len(dist_permutations))
distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]

# update reference dataset
if isinstance(self.update_x_ref, dict) and self.preprocess_fn is not None and self.preprocess_x_ref:
x = self.preprocess_fn(x)
self.x_ref = update_reference(self.x_ref, x, self.n, self.update_x_ref)
# used for reservoir sampling
self.n += len(x) # type: ignore

# populate drift dict
cd = concept_drift_dict()
cd['meta'] = self.meta
cd['data']['is_drift'] = drift_pred
if return_p_val:
cd['data']['p_val'] = p_val
cd['data']['threshold'] = self.p_val
if return_distance:
cd['data']['distance'] = dist
cd['data']['distance_threshold'] = distance_threshold
if return_kernel:
cd['data']['kernel'] = self.kernel # type: ignore
return cd
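
The distance threshold above is the permuted distance at rank `int(p_val * n_permutations)` of the descending sort, so that a fraction of roughly `p_val` of the permuted distances exceed it; a standalone sketch of the logic:

```python
import numpy as np

p_val, n_permutations = .05, 100
dist_permutations = np.random.rand(n_permutations)  # placeholder permuted MMD estimates
idx_threshold = int(p_val * n_permutations)  # = 5
distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]
# exactly idx_threshold permuted distances lie strictly above the threshold
assert (dist_permutations > distance_threshold).sum() == idx_threshold
```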


20 changes: 14 additions & 6 deletions alibi_detect/cd/classifier.py
@@ -24,8 +24,10 @@ def __init__(
preprocess_fn: Optional[Callable] = None,
preds_type: str = 'probs',
binarize_preds: bool = False,
reg_loss_fn: Callable = (lambda model: 0),
train_size: Optional[float] = .75,
n_folds: Optional[int] = None,
retrain_from_scratch: bool = True,
seed: int = 0,
optimizer: Optional[Callable] = None,
learning_rate: float = 1e-3,
@@ -67,6 +69,8 @@ def __init__(
binarize_preds
Whether to test for discrepancy on soft (e.g. probs/logits) model predictions directly
with a K-S test or binarise to 0-1 prediction errors and apply a binomial test.
reg_loss_fn
The regularisation term reg_loss_fn(model) is added to the loss function being optimized.
train_size
Optional fraction (float between 0 and 1) of the dataset used to train the classifier.
The drift is detected on `1 - train_size`. Cannot be used in combination with `n_folds`.
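The `reg_loss_fn` above defaults to a zero penalty; as an illustration (a sketch, not from the commit), an L1 penalty encouraging a sparse discriminator could be passed in the PyTorch backend as:

```python
import torch

def l1_reg(model: torch.nn.Module) -> torch.Tensor:
    # receives the model, returns a scalar added to the training loss
    return 1e-3 * sum(p.abs().sum() for p in model.parameters())
```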
@@ -75,6 +79,9 @@ def __init__(
on all the out-of-fold instances. This allows leveraging all the reference and test data
for drift detection at the expense of longer computation. If both `train_size` and `n_folds`
are specified, `n_folds` is prioritized.
retrain_from_scratch
Whether the classifier should be retrained from scratch for each set of test data or whether
it should instead continue training from where it left off on the previous set.
seed
Optional random seed for fold selection.
optimizer
@@ -125,7 +132,6 @@ def __init__(
kwargs.update({'dataset': TFDataset})
self._detector = ClassifierDriftTF(*args, **kwargs) # type: ignore
else:
kwargs.pop('compile_kwargs', None)
if dataset is None:
kwargs.update({'dataset': TorchDataset})
if dataloader is None:
@@ -134,8 +140,8 @@
self.meta = self._detector.meta

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
return_distance: bool = True, return_probs: bool = True) \
-> Dict[Dict[str, str], Dict[str, Union[int, float]]]:
return_distance: bool = True, return_probs: bool = True, return_model: bool = True) \
-> Dict[str, Dict[str, Union[str, int, float, Callable]]]:
"""
Predict whether a batch of data has drifted from the reference data.
@@ -151,13 +157,15 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
return_probs
Whether to return the instance level classifier probabilities for the reference and test data
(0=reference data, 1=test data).
return_model
Whether to return the updated model trained to discriminate reference and test instances.
Returns
-------
Dictionary containing 'meta' and 'data' dictionaries.
'meta' has the model's metadata.
'data' contains the drift prediction and optionally the performance of the classifier
'data' contains the drift prediction and optionally the p-value, performance of the classifier
relative to its expectation under the no-change null, the out-of-fold classifier model
prediction probabilities on the reference and test data.
prediction probabilities on the reference and test data, and the trained model.
"""
return self._detector.predict(x, return_p_val, return_distance, return_probs)
return self._detector.predict(x, return_p_val, return_distance, return_probs, return_model)
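
Combined with `retrain_from_scratch=False` (and hence a single train/test split rather than folds), the returned model keeps training across successive test batches; a sketch with placeholder data and model, and the `backend` keyword assumed:

```python
import numpy as np
import torch.nn as nn
from alibi_detect.cd import ClassifierDrift

x_ref = np.random.randn(200, 8).astype(np.float32)  # placeholder reference data
model = nn.Sequential(nn.Linear(8, 2))              # placeholder discriminator
cd = ClassifierDrift(x_ref, model, backend='pytorch', retrain_from_scratch=False)

x1 = np.random.randn(100, 8).astype(np.float32)
x2 = np.random.randn(100, 8).astype(np.float32)
cd.predict(x1)
preds = cd.predict(x2)        # training continues from the previous call
clf = preds['data']['model']  # the up-to-date discriminator
```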