From f94a56f8d2a2cae454f2c06f3ce55472de5a5b5a Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Fri, 8 Nov 2024 17:45:23 -0800
Subject: [PATCH] feat: change DeepNCM classifier to a decoder so it can be
 used with different model types; make small changes to DefaultClassifier
 forward_loss to pass the label tensor when needed; update tests

---
 flair/models/__init__.py                     |   4 +-
 flair/models/deepncm_classification_model.py | 328 +++---
 flair/nn/model.py                            |  11 +-
 .../functional/deepncm_trainer_plugin.py     |  13 +-
 tests/models/test_deepncm_classifier.py      |  61 ++--
 5 files changed, 103 insertions(+), 314 deletions(-)

diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index bf3651078..d9fca4a70 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -1,4 +1,4 @@
-from .deepncm_classification_model import DeepNCMClassifier
+from .deepncm_classification_model import DeepNCMDecoder
 from .entity_linker_model import SpanClassifier
 from .entity_mention_linking import EntityMentionLinker
 from .language_model import LanguageModel
@@ -38,5 +38,5 @@
     "TextClassifier",
     "TextRegressor",
     "MultitaskModel",
-    "DeepNCMClassifier",
+    "DeepNCMDecoder",
 ]
diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py
index b942e2891..ec3385a78 100644
--- a/flair/models/deepncm_classification_model.py
+++ b/flair/models/deepncm_classification_model.py
@@ -1,20 +1,15 @@
 import logging
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Literal, Optional

 import torch
-from tqdm import tqdm

 import flair
-from flair.data import Dictionary, Sentence
-from flair.datasets import DataLoader, FlairDatapointDataset
-from flair.embeddings import DocumentEmbeddings
-from flair.embeddings.base import load_embeddings
-from flair.nn import Classifier
+from flair.data import Dictionary

 log = logging.getLogger("flair")


-class DeepNCMClassifier(Classifier[Sentence]):
-    """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks.
+class DeepNCMDecoder(torch.nn.Module):
+    """Deep Nearest Class Mean (DeepNCM) decoder for classification tasks.

     This model combines deep learning with the Nearest Class Mean (NCM) approach.
@@ -32,47 +27,50 @@ class DeepNCMClassifier(
     def __init__(
         self,
-        embeddings: DocumentEmbeddings,
         label_dictionary: Dictionary,
-        label_type: str,
+        embeddings_size: int,
         encoding_dim: Optional[int] = None,
         alpha: float = 0.9,
         mean_update_method: Literal["online", "condensation", "decay"] = "online",
         use_encoder: bool = True,
-        multi_label: bool = False,
-        multi_label_threshold: float = 0.5,
-    ):
-        """Initialize a DeepNCMClassifier.
+        multi_label: bool = False,  # should be taken from the model this decoder belongs to
+    ) -> None:
+        """Initialize a DeepNCMDecoder.

         Args:
-            embeddings: Document embeddings to use for encoding text.
             label_dictionary: Dictionary containing the label vocabulary.
-            label_type: The type of label to predict.
+            embeddings_size: The dimensionality of the embeddings that are passed to the decoder.
             encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings).
-            alpha: The decay factor for updating class prototypes (default is 0.9).
+            alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'.
             mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay').
             use_encoder: Whether to apply an encoder to the input embeddings (default is True).
             multi_label: Whether to predict multiple labels per sentence (default is False).
-            multi_label_threshold: The threshold for multi-label prediction (default is 0.5).
         """
+        super().__init__()

-        self.embeddings = embeddings
         self.label_dictionary = label_dictionary
-        self._label_type = label_type
+        self._num_prototypes = len(label_dictionary)
+
         self.alpha = alpha
         self.mean_update_method = mean_update_method
         self.use_encoder = use_encoder
         self.multi_label = multi_label
-        self.multi_label_threshold = multi_label_threshold
-        self.num_classes = len(label_dictionary)
-        self.embedding_dim = embeddings.embedding_length
+
+        self.embedding_dim = embeddings_size

         if use_encoder:
             self.encoding_dim = encoding_dim or self.embedding_dim
         else:
             self.encoding_dim = self.embedding_dim

+        self.class_prototypes = torch.nn.Parameter(
+            torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False
+        )
+
+        self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False)
+        self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device)
+        self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device)
+
         self._validate_parameters()

         if self.use_encoder:
@@ -84,22 +82,11 @@ def __init__(
         else:
             self.encoder = torch.nn.Sequential(torch.nn.Identity())

-        self.loss_function = (
-            torch.nn.BCEWithLogitsLoss(reduction="sum")
-            if self.multi_label
-            else torch.nn.CrossEntropyLoss(reduction="sum")
-        )
-
-        self.class_prototypes = torch.nn.Parameter(
-            torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False
-        )
-        self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False)
-        self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device)
-        self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device)
+        # all parameters will be pushed internally to the specified device
         self.to(flair.device)

     def _validate_parameters(self) -> None:
-        """Validate the input parameters."""
+        """Validate that the input parameters have valid and compatible values."""
         assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]"
         assert self.mean_update_method in [
             "online",
@@ -108,26 +95,13 @@ def _validate_parameters(self) -> None:
         ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'"
         assert self.encoding_dim > 0, "encoding_dim must be greater than 0"

-    def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
-        """Encode the input sentences using embeddings and optional encoder.
-
-        Args:
-            sentences: Input sentence or list of sentences.
-
-        Returns:
-            torch.Tensor: Encoded representations of the input sentences.
-        """
-        if not isinstance(sentences, list):
-            sentences = [sentences]
-
-        self.embeddings.embed(sentences)
-        sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences])
-        encoded_embeddings = self.encoder(sentence_embeddings)
-
-        return encoded_embeddings
+    @property
+    def num_prototypes(self) -> int:
+        """The number of class prototypes."""
+        return self.class_prototypes.size(0)

     def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor:
-        """Calculate distances between encoded embeddings and class prototypes.
+        """Calculate the squared Euclidean distance between encoded embeddings and class prototypes.

         Args:
             encoded_embeddings: Encoded representations of the input sentences.
@@ -135,60 +109,7 @@ def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor

         Returns:
-            torch.Tensor: Distances between encoded embeddings and class prototypes.
+            torch.Tensor: Squared Euclidean distances between encoded embeddings and class prototypes.
         """
-        return torch.cdist(encoded_embeddings, self.class_prototypes)
-
-    def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]:
-        """Compute the loss for a batch of sentences.
-
-        Args:
-            data_points: A list of sentences.
-
-        Returns:
-            Tuple[torch.Tensor, int]: The total loss and the number of sentences.
-        """
-        encoded_embeddings = self.forward(data_points)
-        labels = self._prepare_label_tensor(data_points)
-        distances = self._calculate_distances(encoded_embeddings)
-        loss = self.loss_function(-distances, labels)
-        self._calculate_prototype_updates(encoded_embeddings, labels)
-
-        return loss, len(data_points)
-
-    def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor:
-        """Prepare the label tensor for the given sentences.
-
-        Args:
-            sentences: A list of sentences.
-
-        Returns:
-            torch.Tensor: The label tensor for the given sentences.
-        """
-        if self.multi_label:
-            return torch.tensor(
-                [
-                    [
-                        (
-                            1
-                            if label
-                            in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)]
-                            else 0
-                        )
-                        for label in self.label_dictionary.get_items()
-                    ]
-                    for sentence in sentences
-                ],
-                dtype=torch.float,
-                device=flair.device,
-            )
-        else:
-            return torch.tensor(
-                [
-                    self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value)
-                    for sentence in sentences
-                ],
-                dtype=torch.long,
-                device=flair.device,
-            )
+        return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2)

     def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None:
         """Calculate updates for class prototypes based on the current batch.
@@ -198,7 +119,7 @@ def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels:
             labels: True labels for the input sentences.
         """
         one_hot = (
-            labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float()
+            labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float()
         )

         updates = torch.matmul(one_hot.t(), encoded_embeddings)
@@ -230,163 +151,25 @@ def update_prototypes(self) -> None:

         # Reset prototype updates
         self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device)
-        self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device)
-
-    def predict(
-        self,
-        sentences: Union[List[Sentence], Sentence],
-        mini_batch_size: int = 32,
-        return_probabilities_for_all_classes: bool = False,
-        verbose: bool = False,
-        label_name: Optional[str] = None,
-        return_loss: bool = False,
-        embedding_storage_mode: str = "none",
-    ) -> Union[List[Sentence], Tuple[float, int]]:
-        """Predict classes for a list of sentences.
-
-        Args:
-            sentences: A list of sentences or a single sentence.
-            mini_batch_size: Size of mini batches during prediction.
-            return_probabilities_for_all_classes: Whether to return probabilities for all classes.
-            verbose: If True, show progress bar during prediction.
-            label_name: The name of the label to use for prediction.
-            return_loss: If True, compute and return loss.
-            embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu').
-
-        Returns:
-            Union[List[Sentence], Tuple[float, int]]:
-                if return_loss is True, returns a tuple of total loss and total number of sentences;
-                otherwise, returns the list of sentences with predicted labels.
-        """
-        with torch.no_grad():
-            if not isinstance(sentences, list):
-                sentences = [sentences]
-            if not sentences:
-                return sentences
-
-            label_name = label_name or self.label_type
-            Sentence.set_context_for_sentences(sentences)
-
-            filtered_sentences = [sent for sent in sentences if len(sent) > 0]
-            reordered_sentences = sorted(filtered_sentences, key=len, reverse=True)
-
-            if len(reordered_sentences) == 0:
-                return sentences
-
-            dataloader = DataLoader(
-                dataset=FlairDatapointDataset(reordered_sentences),
-                batch_size=mini_batch_size,
-            )
-
-            if verbose:
-                progress_bar = tqdm(dataloader)
-                progress_bar.set_description("Predicting")
-                dataloader = progress_bar
-
-            total_loss = 0.0
-            total_sentences = 0
-
-            for batch in dataloader:
-                if not batch:
-                    continue
-
-                encoded_embeddings = self.forward(batch)
-                distances = self._calculate_distances(encoded_embeddings)
-
-                if self.multi_label:
-                    probabilities = torch.sigmoid(-distances)
-                else:
-                    probabilities = torch.nn.functional.softmax(-distances, dim=1)
+        self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device)

-                if return_loss:
-                    labels = self._prepare_label_tensor(batch)
-                    loss = self.loss_function(-distances, labels)
-                    total_loss += loss.item()
-                    total_sentences += len(batch)
+    def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward pass of the decoder, which calculates the scores as prototype distances.

-                for sentence_index, sentence in enumerate(batch):
-                    sentence.remove_labels(label_name)
-
-                    if self.multi_label:
-                        for label_index, probability in enumerate(probabilities[sentence_index]):
-                            if probability > self.multi_label_threshold or return_probabilities_for_all_classes:
-                                label_value = self.label_dictionary.get_item_for_index(label_index)
-                                sentence.add_label(label_name, label_value, probability.item())
-                    else:
-                        predicted_idx = torch.argmax(probabilities[sentence_index])
-                        label_value = self.label_dictionary.get_item_for_index(predicted_idx.item())
-                        sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item())
-
-                    if return_probabilities_for_all_classes:
-                        for label_index, probability in enumerate(probabilities[sentence_index]):
-                            label_value = self.label_dictionary.get_item_for_index(label_index)
-                            sentence.add_label(f"{label_name}_all", label_value, probability.item())
-
-                for sentence in batch:
-                    sentence.clear_embeddings(embedding_storage_mode)
-
-            if return_loss:
-                return total_loss, total_sentences
-            return sentences
-
-    def _get_state_dict(self) -> Dict[str, Any]:
-        """Get the state dictionary of the model.
-
-        Returns:
-            Dict[str, Any]: The state dictionary containing model parameters and configuration.
+        :param embedded: Embedded representations of the input sentences.
+        :param label_tensor: True labels for the input sentences as a tensor.
+        :return: Scores as a tensor of distances to class prototypes.
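+
+        Note: scores are the negated (squared) distances, so a higher score means the embedding is closer to that class prototype.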
""" - model_state = { - "embeddings": self.embeddings.save_embeddings(), - "label_dictionary": self.label_dictionary, - "label_type": self.label_type, - "encoding_dim": self.encoding_dim, - "alpha": self.alpha, - "mean_update_method": self.mean_update_method, - "use_encoder": self.use_encoder, - "multi_label": self.multi_label, - "multi_label_threshold": self.multi_label_threshold, - "class_prototypes": self.class_prototypes.cpu(), - "class_counts": self.class_counts.cpu(), - "encoder": self.encoder.state_dict(), - } - return model_state - - @classmethod - def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": - """Initialize the model from a state dictionary. + encoded_embeddings = self.encoder(embedded) - Args: - state: The state dictionary containing model parameters and configuration. - **kwargs: Additional keyword arguments for model initialization. + distances = self._calculate_distances(encoded_embeddings) - Returns: - DeepNCMClassifier: An instance of the model initialized with the given state. - """ - embeddings = state["embeddings"] - if isinstance(embeddings, dict): - embeddings = load_embeddings(embeddings) - - model = cls( - embeddings=embeddings, - label_dictionary=state["label_dictionary"], - label_type=state["label_type"], - encoding_dim=state["encoding_dim"], - alpha=state["alpha"], - mean_update_method=state["mean_update_method"], - use_encoder=state["use_encoder"], - multi_label=state.get("multi_label", False), - multi_label_threshold=state.get("multi_label_threshold", 0.5), - **kwargs, - ) + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) - if "encoder" in state: - model.encoder.load_state_dict(state["encoder"]) - if "class_prototypes" in state: - model.class_prototypes.data = state["class_prototypes"].to(flair.device) - if "class_counts" in state: - model.class_counts.data = state["class_counts"].to(flair.device) + scores = -distances - return model + return scores def get_prototype(self, class_name: str) -> torch.Tensor: """Get the prototype vector for a given class name. @@ -407,15 +190,15 @@ def get_prototype(self, class_name: str) -> torch.Tensor: return self.class_prototypes[class_idx].clone() - def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]: - """Get the top_k closest prototype vectors to the given input vector using the configured distance metric. + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. Args: input_vector (torch.Tensor): The input vector to compare against prototypes. top_k (int): The number of closest prototypes to return (default is 5). Returns: - List[Tuple[str, float]]: Each tuple contains (class_name, distance). + list[tuple[str, float]]: Each tuple contains (class_name, distance). """ if input_vector.dim() != 1: raise ValueError("Input vector must be a 1D tensor") @@ -434,22 +217,3 @@ def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> nearest_prototypes.append((class_name, value.item())) return nearest_prototypes - - @property - def label_type(self) -> str: - """Get the label type for this classifier.""" - return self._label_type - - def __str__(self) -> str: - """Get a string representation of the model. - - Returns: - str: A string describing the model architecture. 
- """ - return ( - f"DeepNCMClassifier(\n" - f" (embeddings): {self.embeddings}\n" - f" (encoder): {self.encoder}\n" - f" (prototypes): {self.class_prototypes.shape}\n" - f")" - ) diff --git a/flair/nn/model.py b/flair/nn/model.py index f670c969a..d4062c89c 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import torch.nn from torch import Tensor @@ -765,8 +765,11 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # pass data points through network to get encoded data point tensor data_point_tensor = self._encode_data_points(sentences, data_points) - # decode - scores = self.decoder(data_point_tensor) + # decode, passing label tensor if needed, such as for prototype updates + if "label_tensor" in inspect.signature(self.decoder.forward).parameters: + scores = self.decoder(data_point_tensor, label_tensor) + else: + scores = self.decoder(data_point_tensor) # an optional masking step (no masking in most cases) scores = self._mask_scores(scores, data_points) @@ -801,7 +804,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ): + ) -> Optional[Union[List[DT], Tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 2c4c0ccb4..e5394debd 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,6 +1,7 @@ import torch -from flair.models import DeepNCMClassifier, MultitaskModel +from flair.models import MultitaskModel +from flair.models.deepncm_classification_model import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin @@ -11,7 +12,7 @@ class DeepNCMPlugin(TrainerPlugin): """ def _process_models(self, operation: str): - """Process updates for all DeepNCMClassifier models in the trainer. + """Process updates for all DeepNCMDecoder decoders in the trainer. 

         Args:
             operation (str): The operation to perform ('condensation' or 'update')
@@ -21,11 +22,11 @@ def _process_models(self, operation: str):
         models = model.tasks.values() if isinstance(model, MultitaskModel) else [model]

         for sub_model in models:
-            if isinstance(sub_model, DeepNCMClassifier):
-                if operation == "condensation" and sub_model.mean_update_method == "condensation":
-                    sub_model.class_counts.data = torch.ones_like(sub_model.class_counts)
+            if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder):
+                if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation":
+                    sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts)
                 elif operation == "update":
-                    sub_model.update_prototypes()
+                    sub_model.decoder.update_prototypes()

     @TrainerPlugin.hook
     def after_training_epoch(self, **kwargs):
diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py
index 3b76b6c0b..b587a3314 100644
--- a/tests/models/test_deepncm_classifier.py
+++ b/tests/models/test_deepncm_classifier.py
@@ -4,14 +4,14 @@
 from flair.data import Sentence
 from flair.datasets import ClassificationCorpus
 from flair.embeddings import TransformerDocumentEmbeddings
-from flair.models import DeepNCMClassifier
+from flair.models import DeepNCMDecoder, TextClassifier
 from flair.trainers import ModelTrainer
 from flair.trainers.plugins import DeepNCMPlugin
 from tests.model_test_utils import BaseModelTest


-class TestDeepNCMClassifier(BaseModelTest):
-    model_cls = DeepNCMClassifier
+class TestDeepNCMDecoder(BaseModelTest):
+    model_cls = TextClassifier
     train_label_type = "class"
     multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"]
     training_args = {
@@ -33,6 +33,7 @@ def multiclass_train_test_sentence(self):
         return Sentence("This movie was great!")

     def build_model(self, embeddings, label_dict, **kwargs):
         model_args = {
             "embeddings": embeddings,
             "label_dictionary": label_dict,
             "label_type": self.train_label_type,
             "use_encoder": False,
             "encoding_dim": 64,
             "alpha": 0.95,
+            "mean_update_method": "online",
         }
         model_args.update(kwargs)
-        return self.model_cls(**model_args)
+
+        deepncm_decoder = DeepNCMDecoder(
+            label_dictionary=model_args["label_dictionary"],
+            embeddings_size=model_args["embeddings"].embedding_length,
+            alpha=model_args["alpha"],
+            encoding_dim=model_args["encoding_dim"],
+            mean_update_method=model_args["mean_update_method"],
+            use_encoder=model_args["use_encoder"],
+        )
+
+        model = self.model_cls(
+            embeddings=model_args["embeddings"],
+            label_dictionary=model_args["label_dictionary"],
+            label_type=model_args["label_type"],
+            multi_label=model_args.get("multi_label", False),
+            decoder=deepncm_decoder,
+        )
+
+        return model

     @pytest.mark.integration()
     def test_train_load_use_classifier(
@@ -76,24 +95,24 @@ def test_get_prototype(self, corpus, embeddings):
         label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
         model = self.build_model(embeddings, label_dict)

-        prototype = model.get_prototype(next(iter(label_dict.get_items())))
+        prototype = model.decoder.get_prototype(next(iter(label_dict.get_items())))

         assert isinstance(prototype, torch.Tensor)
-        assert prototype.shape == (model.encoding_dim,)
+        assert prototype.shape == (model.decoder.encoding_dim,)

         with pytest.raises(ValueError):
-            model.get_prototype("NON_EXISTENT_CLASS")
+            model.decoder.get_prototype("NON_EXISTENT_CLASS")

     def test_get_closest_prototypes(self, corpus, embeddings):
         label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
         model = self.build_model(embeddings, label_dict)

-        input_vector = torch.randn(model.encoding_dim)
-        closest_prototypes = model.get_closest_prototypes(input_vector, top_k=2)
+        input_vector = torch.randn(model.decoder.encoding_dim)
+        closest_prototypes = model.decoder.get_closest_prototypes(input_vector, top_k=2)

         assert len(closest_prototypes) == 2
         assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes)

         with pytest.raises(ValueError):
-            model.get_closest_prototypes(torch.randn(model.encoding_dim + 1))
+            model.decoder.get_closest_prototypes(torch.randn(model.decoder.encoding_dim + 1))

     def test_forward_loss(self, corpus, embeddings):
         label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
@@ -113,16 +132,16 @@ def test_mean_update_methods(self, corpus, embeddings, mean_update_method):
         label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
         model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method)

-        initial_prototypes = model.class_prototypes.clone()
+        initial_prototypes = model.decoder.class_prototypes.clone()

         sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")]
         for sentence, label in zip(sentences, list(label_dict.get_items())[:2]):
             sentence.add_label(self.train_label_type, label)

         model.forward_loss(sentences)
-        model.update_prototypes()
+        model.decoder.update_prototypes()

-        assert not torch.all(torch.eq(initial_prototypes, model.class_prototypes))
+        assert not torch.all(torch.eq(initial_prototypes, model.decoder.class_prototypes))

     @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"])
     def test_deepncm_plugin(self, corpus, embeddings, mean_update_method):
@@ -133,17 +152,19 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method):
         plugin = DeepNCMPlugin()
         plugin.attach_to(trainer)

-        initial_class_counts = model.class_counts.clone()
-        initial_prototypes = model.class_prototypes.clone()
+        initial_class_counts = model.decoder.class_counts.clone()
+        initial_prototypes = model.decoder.class_prototypes.clone()

         # Simulate training epoch
         plugin.after_training_epoch()

         if mean_update_method == "condensation":
-            assert torch.all(model.class_counts == 1), "Class counts should be 1 for condensation method after epoch"
+            assert torch.all(
+                model.decoder.class_counts == 1
+            ), "Class counts should be 1 for condensation method after epoch"
         elif mean_update_method == "online":
             assert torch.all(
-                torch.eq(model.class_counts, initial_class_counts)
+                torch.eq(model.decoder.class_counts, initial_class_counts)
             ), "Class counts should not change for online method after epoch"

         # Simulate training batch
@@ -154,14 +175,14 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method):
         plugin.after_training_batch()

         assert not torch.all(
-            torch.eq(initial_prototypes, model.class_prototypes)
+            torch.eq(initial_prototypes, model.decoder.class_prototypes)
         ), "Prototypes should be updated after a batch"

         if mean_update_method == "condensation":
             assert torch.all(
-                model.class_counts >= 1
+                model.decoder.class_counts >= 1
             ), "Class counts should be >= 1 for condensation method after a batch"
         elif mean_update_method == "online":
             assert torch.all(
-                model.class_counts > initial_class_counts
+                model.decoder.class_counts > initial_class_counts
             ), "Class counts should increase for online method after a batch"
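
-- 
Usage sketch (after the patch terminator, not part of the diff): a minimal example of how the new
decoder is wired up, following the updated tests. The corpus folder, transformer name, and output
path are illustrative assumptions, as is the fine_tune call with the plugin attached.

    from flair.datasets import ClassificationCorpus
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models import DeepNCMDecoder, TextClassifier
    from flair.trainers import ModelTrainer
    from flair.trainers.plugins import DeepNCMPlugin

    corpus = ClassificationCorpus("path/to/corpus")  # hypothetical data folder
    label_dict = corpus.make_label_dictionary(label_type="class")
    embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased")

    # the decoder holds the class prototypes; its scores are negated distances to them
    decoder = DeepNCMDecoder(
        label_dictionary=label_dict,
        embeddings_size=embeddings.embedding_length,
        mean_update_method="condensation",
    )

    # any DefaultClassifier subclass that accepts a custom decoder can use it
    model = TextClassifier(
        embeddings=embeddings,
        label_dictionary=label_dict,
        label_type="class",
        decoder=decoder,
    )

    # DeepNCMPlugin applies the accumulated prototype updates during training
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("resources/deepncm", plugins=[DeepNCMPlugin()])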