Fixed pre-commit issues
Negiiiin committed Dec 19, 2024
1 parent bfec0cf commit 023ca9c
Showing 1 changed file with 58 additions and 37 deletions.
95 changes: 58 additions & 37 deletions mmlearn/tasks/linear_probing.py
@@ -1,41 +1,54 @@
"""A Module for linear evaluation of pretrained encoders."""

import inspect
from contextlib import nullcontext
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import hydra
import inspect
import lightning as L # noqa: N812
import torch
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning_utilities.core.rank_zero import rank_zero_warn
from omegaconf import DictConfig
from hydra_zen import store
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import nn
from torchmetrics import MetricCollection, Accuracy, AUROC, Precision, Recall, F1Score
from hydra_zen import store
from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall

from mmlearn.datasets.core import Modalities, find_matching_indices
from mmlearn.datasets.core.modalities import Modality
from mmlearn.tasks.hooks import EvaluationHooks
from mmlearn.datasets.core import Modalities


def extract_vision_encoder(encoder: Any, encoder_checkpoint_path: Optional[str]) -> nn.Module:
def extract_vision_encoder(
encoder: Any, encoder_checkpoint_path: Optional[str]
) -> nn.Module:
"""
Extract the vision encoder from a PyTorch Lightning model.
Args:
encoder (Any): The encoder configuration or model to be instantiated.
encoder_checkpoint_path (Optional[str]): Path to the checkpoint file containing
the encoder's state_dict.
Returns
-------
nn.Module: The vision encoder module extracted and initialized.
"""
model: L.LightningModule = hydra.utils.instantiate(encoder)
if encoder_checkpoint_path is None:
rank_zero_warn("No encoder_checkpoint_path path was provided for linear evaluation.")
rank_zero_warn(
"No encoder_checkpoint_path path was provided for linear evaluation."
)
else:
checkpoint = torch.load(encoder_checkpoint_path)
if 'state_dict' not in checkpoint:
if "state_dict" not in checkpoint:
raise KeyError("'state_dict' not found in checkpoint")

state_dict = checkpoint['state_dict']
state_dict = checkpoint["state_dict"]

# Filter keys that are related to vision encoder
encoder_keys = {
k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v
for k, v in state_dict.items() if "encoders.rgb" in k
for k, v in state_dict.items()
if "encoders.rgb" in k
}
try:
if encoder_keys:
@@ -114,7 +127,7 @@ class LinearClassifierModule(L.LightningModule):
def __init__(
self,
# encoder: torch.nn.Module,
encoder: nn.Module,
encoder: nn.Module,
encoder_checkpoint_path: Optional[str],
modality: str,
num_output_features: int,
@@ -139,8 +152,10 @@ def __init__(
)

self.modality = modality

self.encoder: nn.Module = extract_vision_encoder(encoder, encoder_checkpoint_path)

self.encoder: nn.Module = extract_vision_encoder(
encoder, encoder_checkpoint_path
)

linear_layer = nn.Linear(num_output_features, num_classes)
if pre_classifier_batch_norm:
@@ -164,40 +179,50 @@ def __init__(
if self.top_k_list is None:
self.top_k_list = [1, 5]
accuracy_metrics = {
f"top_{k}_accuracy": Accuracy(task=task, num_classes=num_classes, top_k=k)
f"top_{k}_accuracy": Accuracy(
task=task, num_classes=num_classes, top_k=k
)
for k in self.top_k_list
}

# Additional metrics for multiclass classification
additional_metrics = {
"precision": Precision(task=task, num_classes=num_classes, average="macro"),
"precision": Precision(
task=task, num_classes=num_classes, average="macro"
),
"recall": Recall(task=task, num_classes=num_classes, average="macro"),
"f1_score": F1Score(task=task, num_classes=num_classes, average="macro"),
"auc": AUROC(task=task, num_classes=num_classes, average="macro") # AUROC for multiclass
"f1_score": F1Score(
task=task, num_classes=num_classes, average="macro"
),
"auc": AUROC(
task=task, num_classes=num_classes, average="macro"
), # AUROC for multiclass
}

elif task == "multilabel":
# Accuracy and other metrics for multilabel classification
accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)}

# Additional metrics for multilabel classification
additional_metrics = {
"precision": Precision(task=task, num_labels=num_classes, average="macro"),
"precision": Precision(
task=task, num_labels=num_classes, average="macro"
),
"recall": Recall(task=task, num_labels=num_classes, average="macro"),
"f1_score": F1Score(task=task, num_labels=num_classes, average="macro"),
"auc": AUROC(task=task, num_labels=num_classes) # AUC for multilabel
"auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel
}

else: # binary
# Accuracy and other metrics for binary classification
accuracy_metrics = {"accuracy": Accuracy(task=task)}

# Additional metrics for binary classification
additional_metrics = {
"precision": Precision(task=task),
"recall": Recall(task=task),
"f1_score": F1Score(task=task),
"auc": AUROC(task=task) # AUROC for binary classification
"auc": AUROC(task=task), # AUROC for binary classification
}

# combine all metrics
@@ -230,18 +255,16 @@ def _get_logits_and_labels(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return the logits and labels for a batch of data."""
x : torch.Tensor = batch
x: torch.Tensor = batch
y = batch[Modalities.get_modality(self.modality).target]

logits = self(x)
return logits, y

def _compute_loss(
self, batch: Dict[str, Any]
) -> Optional[torch.Tensor]:
def _compute_loss(self, batch: Dict[str, Any]) -> Optional[torch.Tensor]:
if self.loss_fn is None:
return None

if self.freeze_encoder:
self.encoder.eval()

@@ -296,7 +319,6 @@ def validation_step(
torch.Tensor
The loss computed for the batch.
"""

logits, y = self._get_logits_and_labels(batch)

loss: torch.Tensor = self.loss_fn(logits, y)
@@ -312,7 +334,6 @@ def on_validation_epoch_end(self) -> None:
self.log_dict(val_metrics)
self.valid_metrics.reset()


def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912
"""Configure the optimizer and learning rate scheduler."""
if self.optimizer is None:
@@ -389,7 +410,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
if isinstance(extras, partial):
# Extract the keywords from the partial object
lr_scheduler_dict.update(extras.keywords)

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

lr_scheduler = self.lr_scheduler(optimizer)
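As context for the checkpoint-filtering logic in extract_vision_encoder above, here is a minimal, self-contained sketch of the same idea: load a Lightning checkpoint, keep only the entries that belong to the RGB vision encoder, and strip their prefix so they match the standalone encoder's parameter names. The checkpoint path is a placeholder; this illustrates the pattern shown in the diff rather than reproducing the module's actual code.

import torch


def filter_rgb_encoder_state(checkpoint_path: str) -> dict:
    """Return only the RGB vision-encoder weights from a Lightning checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" not in checkpoint:
        raise KeyError("'state_dict' not found in checkpoint")
    state_dict = checkpoint["state_dict"]
    # Keep keys that belong to the RGB encoder and drop the "encoders.rgb." prefix
    # so they can be loaded into the encoder module directly.
    return {
        k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v
        for k, v in state_dict.items()
        if "encoders.rgb" in k
    }


# Example usage (path is hypothetical):
# encoder.load_state_dict(filter_rgb_encoder_state("checkpoints/pretrained.ckpt"))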

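The metric setup in the __init__ diff can also be reproduced on its own. The sketch below builds the same multiclass MetricCollection (top-k accuracies plus macro precision, recall, F1 and AUROC) outside the module; the num_classes, top_k_list and batch values are arbitrary examples, not values from the repository.

import torch
from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall

num_classes = 10
top_k_list = [1, 5]

metrics = MetricCollection(
    {
        **{
            f"top_{k}_accuracy": Accuracy(
                task="multiclass", num_classes=num_classes, top_k=k
            )
            for k in top_k_list
        },
        "precision": Precision(task="multiclass", num_classes=num_classes, average="macro"),
        "recall": Recall(task="multiclass", num_classes=num_classes, average="macro"),
        "f1_score": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
        "auc": AUROC(task="multiclass", num_classes=num_classes, average="macro"),
    }
)

# Update with a batch of logits and integer targets, then compute every metric at once.
logits = torch.randn(8, num_classes)
targets = torch.randint(0, num_classes, (8,))
metrics.update(logits, targets)
print(metrics.compute())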
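Finally, the configure_optimizers fragment at the end of the diff returns Lightning's standard optimizer/lr_scheduler dictionary built from partial-bound factories. A minimal standalone version of that pattern, with an assumed linear probe and an assumed cosine schedule (neither taken from the repository), looks roughly like this:

from functools import partial

import torch
from torch import nn

# Stand-in for the probe's classifier head; the real task also wraps a frozen encoder.
model = nn.Linear(768, 10)

# The task receives the optimizer and lr_scheduler as partial-bound factories and calls
# them with the parameters / optimizer, similar to the diff's configure_optimizers.
optimizer_factory = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.01)
scheduler_factory = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100)

optimizer = optimizer_factory(model.parameters())
lr_scheduler_dict = {"scheduler": scheduler_factory(optimizer), "interval": "epoch"}

# This is the shape of the dictionary returned to Lightning.
optimizer_config = {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}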