Skip to content

Commit

Permalink
flairNLPGH-3496: Add OneClassClassifier model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffpicard committed Jul 12, 2024
1 parent c6a2643 commit 52d42a6
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 0 deletions.
203 changes: 203 additions & 0 deletions flair/models/anomaly_detector_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import Any, Optional, Union, List, Tuple, Dict

import flair
import numpy as np
import torch
from flair.data import Dictionary, Sentence
from flair.embeddings import TokenEmbeddings
from flair.training_utils import store_embeddings
from torch.utils.data import Dataset


class AnomalyDetector(flair.nn.Classifier[Sentence]):
"""A one-class classifier created to serve as an is_resume model.
It uses a reconstruction method based on autoencoders. The score is the reconstruction error from compressing
and decompressing a document. See https://en.wikipedia.org/wiki/One-class_classification#Reconstruction_methods
for more. If the score is LOWER than the threshold, a label will be added with value equal to the trained class.
Otherwise, the value will be "<unk>".
You must set the threshold after training by running model.threshold = model.calculate_threshold(corpus.dev).
"""

def __init__(
self,
embeddings: TokenEmbeddings,
label_dictionary: Dictionary,
encoding_dim: int,
label_type: str,
threshold: Optional[float] = None,
) -> None:
"""
Args:
embeddings: Embeddings to use during training and prediction
label_dictionary: The label to predict. Must contain exactly one class.
label_type: name of the annotation_layer to be predicted in case a corpus has multiple annotations
encoding_dim: The size of the compressed embedding
threshold: The score that separates in-class from out-of-class
"""
super().__init__()
self.embeddings = embeddings
if len(label_dictionary) != 1:
raise ValueError(f"label_dictionary must have exactly 1 class: {label_dictionary}")
self.label_dictionary = label_dictionary
self.label_value = label_dictionary.get_items()[0]
self.encoding_dim = encoding_dim
self._label_type = label_type
self.threshold = threshold
embedding_dim = embeddings.embedding_length
self.encoder = torch.nn.Sequential(
torch.nn.Linear(embedding_dim, encoding_dim),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim, encoding_dim // 2),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim // 2, encoding_dim // 4),
torch.nn.LeakyReLU(True),
)

self.decoder = torch.nn.Sequential(
torch.nn.Linear(encoding_dim // 4, encoding_dim // 2),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim // 2, encoding_dim),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim, embedding_dim),
torch.nn.LeakyReLU(True),
)

self.cosine_sim = torch.nn.CosineSimilarity(dim=1)
self.to(flair.device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.decoder(x)
return x

def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
"""
:return: Tuple[scalar tensor, num examples]
"""
if len(sentences) == 0:
return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
sentence_tensor = self._sentences_to_tensor(sentences)
reconstructed_sentence_tensor = self.forward(sentence_tensor)
return self._loss(reconstructed_sentence_tensor, sentence_tensor).sum(), len(sentences)

def predict(
self,
sentences: Union[List[Sentence], Sentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
label_name: Optional[str] = None,
return_loss=False,
embedding_storage_mode="none",
) -> Optional[torch.Tensor]:
"""Predicts the class labels for the given sentences. The labels are directly added to the sentences.
:param sentences: list of sentences to predict
:param mini_batch_size: the amount of sentences that will be predicted within one batch (unimplemented)
:param return_probabilities_for_all_classes: return probabilities for all classes instead of only best predicted (unimplemented)
:param verbose: set to True to display a progress bar (unimplemented)
:param return_loss: set to True to return loss
:param label_name: set this to change the name of the label type that is predicted
:param embedding_storage_mode: default is 'none' which is the best is most cases.
Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory.
:return: None. If return_loss is set, returns a scalar tensor
"""
if label_name is None:
label_name = self.label_type

with torch.no_grad():
# make sure it's a list
if not isinstance(sentences, list):
sentences = [sentences]

# filter empty sentences
sentences = [sentence for sentence in sentences if len(sentence) > 0]
if len(sentences) == 0:
return torch.tensor(0.0, requires_grad=True, device=flair.device) if return_loss else None

sentence_tensor = self._sentences_to_tensor(sentences)
reconstructed = self.forward(sentence_tensor)
loss_tensor = self._loss(reconstructed, sentence_tensor)

for sentence, loss in zip(sentences, loss_tensor.tolist()):
sentence.remove_labels(label_name)
if self.threshold is not None and loss < self.threshold:
label_value = self.label_value
else:
label_value = "<unk>"
sentence.add_label(typename=label_name, value=label_value, score=loss)

store_embeddings(sentences, storage_mode=embedding_storage_mode)

if return_loss:
return loss_tensor.sum()

@property
def label_type(self) -> str:
return self._label_type

def _sentences_to_tensor(self, sentences: List[Sentence]) -> torch.Tensor:
self.embeddings.embed(sentences)
return torch.stack([sentence.embedding for sentence in sentences])

def _loss(self, predicted: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Return cosine similarity loss
:param predicted: tensor of shape (batch_size, embedding_size)
:param labels: tensor of shape (batch_size, embedding_size)
:return: tensor of shape (batch_size)
"""

if labels.size(0) == 0:
return torch.tensor(0.0, requires_grad=True, device=flair.device)

return 1 - self.cosine_sim(predicted, labels)

def _get_state_dict(self):
"""Returns the state dictionary for this model."""
model_state = {
**super()._get_state_dict(),
"embeddings": self.embeddings.save_embeddings(use_state_dict=False),
"label_dictionary": self.label_dictionary,
"encoding_dim": self.encoding_dim,
"label_type": self.label_type,
"threshold": self.threshold,
}

return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
return super()._init_model_with_state_dict(
state,
embeddings=state.get("embeddings"),
label_dictionary=state.get("label_dictionary"),
encoding_dim=state.get("encoding_dim"),
label_type=state.get("label_type"),
threshold=state.get("threshold"),
**kwargs,
)

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "AnomalyDetector":
from typing import cast

return cast("AnomalyDetector", super().load(model_path=model_path))

def calculate_threshold(self, dataset: Dataset[Sentence]) -> float:
"""Determine the score threshold to consider a Sentence in-class. This implementation returns the score
at which 99.5% of `dataset` will be considered in-class. Intended for use-cases targeting nearly-perfect recall.
"""

def score(sentence: Sentence) -> float:
sentence_tensor = self._sentences_to_tensor([sentence])
reconstructed = self.forward(sentence_tensor)
loss_tensor = self._loss(reconstructed, sentence_tensor)
return loss_tensor.tolist()[0]

scores = [score(sentence) for sentence in dataset]
threshold = np.quantile(scores, 0.995)
return threshold
10 changes: 10 additions & 0 deletions tests/models/test_anomaly_detector_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from flair.embeddings import TransformerWordEmbeddings
from tests.model_test_utils import BaseModelTest


class TestAnomalyDetector(BaseModelTest):
@pytest.fixture()
def embeddings(self):
return TransformerWordEmbeddings("distilbert-base-uncased")

0 comments on commit 52d42a6

Please sign in to comment.