diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py
index 9069666..53b5a6b 100644
--- a/mmlearn/tasks/linear_probing.py
+++ b/mmlearn/tasks/linear_probing.py
@@ -1,41 +1,54 @@
 """A Module for linear evaluation of pretrained encoders."""
 
+import inspect
 from contextlib import nullcontext
 from functools import partial
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import hydra
-import inspect
 import lightning as L  # noqa: N812
 import torch
-from lightning.fabric.utilities.cloud_io import _load as pl_load
-from lightning_utilities.core.rank_zero import rank_zero_warn
-from omegaconf import DictConfig
+from hydra_zen import store
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
+from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import nn
-from torchmetrics import MetricCollection, Accuracy, AUROC, Precision, Recall, F1Score
-from hydra_zen import store
 
-from mmlearn.datasets.core import Modalities, find_matching_indices
-from mmlearn.datasets.core.modalities import Modality
-from mmlearn.tasks.hooks import EvaluationHooks
+from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall
+
+from mmlearn.datasets.core import Modalities
 
 
-def extract_vision_encoder(encoder: Any, encoder_checkpoint_path: Optional[str]) -> nn.Module:
+def extract_vision_encoder(
+    encoder: Any, encoder_checkpoint_path: Optional[str]
+) -> nn.Module:
+    """Extract the vision encoder from a PyTorch Lightning model.
+
+    Parameters
+    ----------
+    encoder : Any
+        The encoder configuration or model to be instantiated.
+    encoder_checkpoint_path : Optional[str]
+        Path to the checkpoint file containing the encoder's state_dict.
+
+    Returns
+    -------
+    nn.Module
+        The vision encoder module, extracted and initialized.
+    """
     model: L.LightningModule = hydra.utils.instantiate(encoder)
     if encoder_checkpoint_path is None:
-        rank_zero_warn("No encoder_checkpoint_path path was provided for linear evaluation.")
+        rank_zero_warn(
+            "No encoder_checkpoint_path was provided for linear evaluation."
+        )
     else:
         checkpoint = torch.load(encoder_checkpoint_path)
-        if 'state_dict' not in checkpoint:
+        if "state_dict" not in checkpoint:
             raise KeyError("'state_dict' not found in checkpoint")
-        state_dict = checkpoint['state_dict']
+        state_dict = checkpoint["state_dict"]
 
         # Filter keys that are related to vision encoder
         encoder_keys = {
             k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v
-            for k, v in state_dict.items() if "encoders.rgb" in k
+            for k, v in state_dict.items()
+            if "encoders.rgb" in k
        }
        try:
            if encoder_keys:
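Note: the key filtering in the hunk above assumes the pretraining checkpoint stores the vision encoder under an "encoders.rgb." prefix. A minimal, self-contained sketch of what the comprehension produces, using a made-up state_dict (the keys and values below are placeholders, not from a real checkpoint):

state_dict = {
    "encoders.rgb.patch_embed.weight": "vision tensor",
    "encoders.rgb.blocks.0.attn.qkv.weight": "vision tensor",
    "encoders.text.embeddings.weight": "text tensor",
}
encoder_keys = {
    k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v
    for k, v in state_dict.items()
    if "encoders.rgb" in k
}
# -> {"patch_embed.weight": "vision tensor",
#     "blocks.0.attn.qkv.weight": "vision tensor"}

The prefix is stripped so the remaining keys line up with the standalone encoder module's own state_dict.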
@@ -114,7 +127,7 @@ class LinearClassifierModule(L.LightningModule):
     def __init__(
         self,
         # encoder: torch.nn.Module,
-        encoder: nn.Module, 
+        encoder: nn.Module,
         encoder_checkpoint_path: Optional[str],
         modality: str,
         num_output_features: int,
@@ -139,8 +152,10 @@ def __init__(
         )
 
         self.modality = modality
-        
-        self.encoder: nn.Module = extract_vision_encoder(encoder, encoder_checkpoint_path)
+
+        self.encoder: nn.Module = extract_vision_encoder(
+            encoder, encoder_checkpoint_path
+        )
 
         linear_layer = nn.Linear(num_output_features, num_classes)
         if pre_classifier_batch_norm:
@@ -164,40 +179,50 @@ def __init__(
             if self.top_k_list is None:
                 self.top_k_list = [1, 5]
             accuracy_metrics = {
-                f"top_{k}_accuracy": Accuracy(task=task, num_classes=num_classes, top_k=k)
+                f"top_{k}_accuracy": Accuracy(
+                    task=task, num_classes=num_classes, top_k=k
+                )
                 for k in self.top_k_list
             }
-            
+
             # Additional metrics for multiclass classification
             additional_metrics = {
-                "precision": Precision(task=task, num_classes=num_classes, average="macro"),
+                "precision": Precision(
+                    task=task, num_classes=num_classes, average="macro"
+                ),
                 "recall": Recall(task=task, num_classes=num_classes, average="macro"),
-                "f1_score": F1Score(task=task, num_classes=num_classes, average="macro"),
-                "auc": AUROC(task=task, num_classes=num_classes, average="macro")  # AUROC for multiclass
+                "f1_score": F1Score(
+                    task=task, num_classes=num_classes, average="macro"
+                ),
+                "auc": AUROC(
+                    task=task, num_classes=num_classes, average="macro"
+                ),  # AUROC for multiclass
             }
         elif task == "multilabel":
             # Accuracy and other metrics for multilabel classification
             accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)}
-            
+
             # Additional metrics for multilabel classification
             additional_metrics = {
-                "precision": Precision(task=task, num_labels=num_classes, average="macro"),
+                "precision": Precision(
+                    task=task, num_labels=num_classes, average="macro"
+                ),
                 "recall": Recall(task=task, num_labels=num_classes, average="macro"),
                 "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"),
-                "auc": AUROC(task=task, num_labels=num_classes)  # AUC for multilabel
+                "auc": AUROC(task=task, num_labels=num_classes),  # AUC for multilabel
             }
         else:  # binary
             # Accuracy and other metrics for binary classification
             accuracy_metrics = {"accuracy": Accuracy(task=task)}
-            
+
             # Additional metrics for binary classification
             additional_metrics = {
                 "precision": Precision(task=task),
                 "recall": Recall(task=task),
                 "f1_score": F1Score(task=task),
-                "auc": AUROC(task=task)  # AUROC for binary classification
+                "auc": AUROC(task=task),  # AUROC for binary classification
             }
 
         # combine all metrics
@@ -230,18 +255,16 @@ def _get_logits_and_labels(
         self, batch: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Return the logits and labels for a batch of data."""
-        x : torch.Tensor = batch
+        x: Dict[str, torch.Tensor] = batch
         y = batch[Modalities.get_modality(self.modality).target]
-        
+
         logits = self(x)
         return logits, y
 
-    def _compute_loss(
-        self, batch: Dict[str, Any]
-    ) -> Optional[torch.Tensor]:
+    def _compute_loss(self, batch: Dict[str, Any]) -> Optional[torch.Tensor]:
+        """Compute the loss for a batch, if a loss function is configured."""
         if self.loss_fn is None:
             return None
-        
+
         if self.freeze_encoder:
             self.encoder.eval()
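Note: the three task branches in the metric hunk above all feed a single torchmetrics MetricCollection. A short sketch of the intended update/compute/reset cycle, shown for the binary task with illustrative tensors (the "val/" prefix and variable names are assumptions, not the module's exact attributes):

import torch
from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall

metrics = MetricCollection(
    {
        "accuracy": Accuracy(task="binary"),
        "precision": Precision(task="binary"),
        "recall": Recall(task="binary"),
        "f1_score": F1Score(task="binary"),
        "auc": AUROC(task="binary"),
    },
    prefix="val/",
)
preds = torch.tensor([0.9, 0.2, 0.7, 0.4])  # predicted probabilities
target = torch.tensor([1, 0, 1, 1])
metrics.update(preds, target)  # accumulate state once per batch
print(metrics.compute())       # aggregate at epoch end
metrics.reset()                # clear state for the next epoch

This mirrors the validation flow in the module: update per step, compute and log in on_validation_epoch_end, then reset.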
@@ -296,7 +319,6 @@ def validation_step(
         torch.Tensor
             The loss computed for the batch.
         """
-
         logits, y = self._get_logits_and_labels(batch)
         loss: torch.Tensor = self.loss_fn(logits, y)
 
@@ -312,7 +334,6 @@ def on_validation_epoch_end(self) -> None:
         self.log_dict(val_metrics)
         self.valid_metrics.reset()
 
-
     def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
         """Configure the optimizer and learning rate scheduler."""
         if self.optimizer is None:
@@ -389,7 +410,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
             if isinstance(extras, partial):
                 # Extract the keywords from the partial object
                 lr_scheduler_dict.update(extras.keywords)
-            
+
             return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
 
         lr_scheduler = self.lr_scheduler(optimizer)
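Note: the last hunk handles the case where the scheduler's extra options arrive as a functools.partial, as they can when declared ahead of time in a Hydra/hydra-zen config. A small sketch of that pattern with made-up hyperparameters; the StepLR choice, the extras-as-partial construction, and the interval/frequency values are illustrative only:

from functools import partial

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Declared in config before the optimizer exists; called here to finish construction.
scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)
lr_scheduler_dict = {"scheduler": scheduler_partial(optimizer)}

# Lightning-specific keys ride along as the partial's keywords, mirroring
# the `lr_scheduler_dict.update(extras.keywords)` call in the hunk above.
extras = partial(dict, interval="epoch", frequency=1)
if isinstance(extras, partial):
    lr_scheduler_dict.update(extras.keywords)
# -> {"scheduler": <StepLR>, "interval": "epoch", "frequency": 1}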