forked from flairNLP/flair
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
flairNLPGH-3496: Add OneClassClassifier model
- Loading branch information
1 parent
c6a2643
commit 52d42a6
Showing
2 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
from pathlib import Path | ||
from typing import Any, Optional, Union, List, Tuple, Dict | ||
|
||
import flair | ||
import numpy as np | ||
import torch | ||
from flair.data import Dictionary, Sentence | ||
from flair.embeddings import TokenEmbeddings | ||
from flair.training_utils import store_embeddings | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class AnomalyDetector(flair.nn.Classifier[Sentence]): | ||
"""A one-class classifier created to serve as an is_resume model. | ||
It uses a reconstruction method based on autoencoders. The score is the reconstruction error from compressing | ||
and decompressing a document. See https://en.wikipedia.org/wiki/One-class_classification#Reconstruction_methods | ||
for more. If the score is LOWER than the threshold, a label will be added with value equal to the trained class. | ||
Otherwise, the value will be "<unk>". | ||
You must set the threshold after training by running model.threshold = model.calculate_threshold(corpus.dev). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
embeddings: TokenEmbeddings, | ||
label_dictionary: Dictionary, | ||
encoding_dim: int, | ||
label_type: str, | ||
threshold: Optional[float] = None, | ||
) -> None: | ||
""" | ||
Args: | ||
embeddings: Embeddings to use during training and prediction | ||
label_dictionary: The label to predict. Must contain exactly one class. | ||
label_type: name of the annotation_layer to be predicted in case a corpus has multiple annotations | ||
encoding_dim: The size of the compressed embedding | ||
threshold: The score that separates in-class from out-of-class | ||
""" | ||
super().__init__() | ||
self.embeddings = embeddings | ||
if len(label_dictionary) != 1: | ||
raise ValueError(f"label_dictionary must have exactly 1 class: {label_dictionary}") | ||
self.label_dictionary = label_dictionary | ||
self.label_value = label_dictionary.get_items()[0] | ||
self.encoding_dim = encoding_dim | ||
self._label_type = label_type | ||
self.threshold = threshold | ||
embedding_dim = embeddings.embedding_length | ||
self.encoder = torch.nn.Sequential( | ||
torch.nn.Linear(embedding_dim, encoding_dim), | ||
torch.nn.LeakyReLU(True), | ||
torch.nn.Linear(encoding_dim, encoding_dim // 2), | ||
torch.nn.LeakyReLU(True), | ||
torch.nn.Linear(encoding_dim // 2, encoding_dim // 4), | ||
torch.nn.LeakyReLU(True), | ||
) | ||
|
||
self.decoder = torch.nn.Sequential( | ||
torch.nn.Linear(encoding_dim // 4, encoding_dim // 2), | ||
torch.nn.LeakyReLU(True), | ||
torch.nn.Linear(encoding_dim // 2, encoding_dim), | ||
torch.nn.LeakyReLU(True), | ||
torch.nn.Linear(encoding_dim, embedding_dim), | ||
torch.nn.LeakyReLU(True), | ||
) | ||
|
||
self.cosine_sim = torch.nn.CosineSimilarity(dim=1) | ||
self.to(flair.device) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.encoder(x) | ||
x = self.decoder(x) | ||
return x | ||
|
||
def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: | ||
""" | ||
:return: Tuple[scalar tensor, num examples] | ||
""" | ||
if len(sentences) == 0: | ||
return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0 | ||
sentence_tensor = self._sentences_to_tensor(sentences) | ||
reconstructed_sentence_tensor = self.forward(sentence_tensor) | ||
return self._loss(reconstructed_sentence_tensor, sentence_tensor).sum(), len(sentences) | ||
|
||
def predict( | ||
self, | ||
sentences: Union[List[Sentence], Sentence], | ||
mini_batch_size: int = 32, | ||
return_probabilities_for_all_classes: bool = False, | ||
verbose: bool = False, | ||
label_name: Optional[str] = None, | ||
return_loss=False, | ||
embedding_storage_mode="none", | ||
) -> Optional[torch.Tensor]: | ||
"""Predicts the class labels for the given sentences. The labels are directly added to the sentences. | ||
:param sentences: list of sentences to predict | ||
:param mini_batch_size: the amount of sentences that will be predicted within one batch (unimplemented) | ||
:param return_probabilities_for_all_classes: return probabilities for all classes instead of only best predicted (unimplemented) | ||
:param verbose: set to True to display a progress bar (unimplemented) | ||
:param return_loss: set to True to return loss | ||
:param label_name: set this to change the name of the label type that is predicted | ||
:param embedding_storage_mode: default is 'none' which is the best is most cases. | ||
Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory. | ||
:return: None. If return_loss is set, returns a scalar tensor | ||
""" | ||
if label_name is None: | ||
label_name = self.label_type | ||
|
||
with torch.no_grad(): | ||
# make sure it's a list | ||
if not isinstance(sentences, list): | ||
sentences = [sentences] | ||
|
||
# filter empty sentences | ||
sentences = [sentence for sentence in sentences if len(sentence) > 0] | ||
if len(sentences) == 0: | ||
return torch.tensor(0.0, requires_grad=True, device=flair.device) if return_loss else None | ||
|
||
sentence_tensor = self._sentences_to_tensor(sentences) | ||
reconstructed = self.forward(sentence_tensor) | ||
loss_tensor = self._loss(reconstructed, sentence_tensor) | ||
|
||
for sentence, loss in zip(sentences, loss_tensor.tolist()): | ||
sentence.remove_labels(label_name) | ||
if self.threshold is not None and loss < self.threshold: | ||
label_value = self.label_value | ||
else: | ||
label_value = "<unk>" | ||
sentence.add_label(typename=label_name, value=label_value, score=loss) | ||
|
||
store_embeddings(sentences, storage_mode=embedding_storage_mode) | ||
|
||
if return_loss: | ||
return loss_tensor.sum() | ||
|
||
@property | ||
def label_type(self) -> str: | ||
return self._label_type | ||
|
||
def _sentences_to_tensor(self, sentences: List[Sentence]) -> torch.Tensor: | ||
self.embeddings.embed(sentences) | ||
return torch.stack([sentence.embedding for sentence in sentences]) | ||
|
||
def _loss(self, predicted: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: | ||
"""Return cosine similarity loss | ||
:param predicted: tensor of shape (batch_size, embedding_size) | ||
:param labels: tensor of shape (batch_size, embedding_size) | ||
:return: tensor of shape (batch_size) | ||
""" | ||
|
||
if labels.size(0) == 0: | ||
return torch.tensor(0.0, requires_grad=True, device=flair.device) | ||
|
||
return 1 - self.cosine_sim(predicted, labels) | ||
|
||
def _get_state_dict(self): | ||
"""Returns the state dictionary for this model.""" | ||
model_state = { | ||
**super()._get_state_dict(), | ||
"embeddings": self.embeddings.save_embeddings(use_state_dict=False), | ||
"label_dictionary": self.label_dictionary, | ||
"encoding_dim": self.encoding_dim, | ||
"label_type": self.label_type, | ||
"threshold": self.threshold, | ||
} | ||
|
||
return model_state | ||
|
||
@classmethod | ||
def _init_model_with_state_dict(cls, state, **kwargs): | ||
return super()._init_model_with_state_dict( | ||
state, | ||
embeddings=state.get("embeddings"), | ||
label_dictionary=state.get("label_dictionary"), | ||
encoding_dim=state.get("encoding_dim"), | ||
label_type=state.get("label_type"), | ||
threshold=state.get("threshold"), | ||
**kwargs, | ||
) | ||
|
||
@classmethod | ||
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "AnomalyDetector": | ||
from typing import cast | ||
|
||
return cast("AnomalyDetector", super().load(model_path=model_path)) | ||
|
||
def calculate_threshold(self, dataset: Dataset[Sentence]) -> float: | ||
"""Determine the score threshold to consider a Sentence in-class. This implementation returns the score | ||
at which 99.5% of `dataset` will be considered in-class. Intended for use-cases targeting nearly-perfect recall. | ||
""" | ||
|
||
def score(sentence: Sentence) -> float: | ||
sentence_tensor = self._sentences_to_tensor([sentence]) | ||
reconstructed = self.forward(sentence_tensor) | ||
loss_tensor = self._loss(reconstructed, sentence_tensor) | ||
return loss_tensor.tolist()[0] | ||
|
||
scores = [score(sentence) for sentence in dataset] | ||
threshold = np.quantile(scores, 0.995) | ||
return threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import pytest | ||
|
||
from flair.embeddings import TransformerWordEmbeddings | ||
from tests.model_test_utils import BaseModelTest | ||
|
||
|
||
class TestAnomalyDetector(BaseModelTest): | ||
@pytest.fixture() | ||
def embeddings(self): | ||
return TransformerWordEmbeddings("distilbert-base-uncased") |