typecheck(cli/_helpers.py) #1667

Open
wants to merge 10 commits into base: master
Changes from 9 commits
55 changes: 39 additions & 16 deletions lightly/cli/_helpers.py
@@ -2,7 +2,10 @@

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
from __future__ import annotations

import os
from typing import Any

import hydra
import torch
@@ -17,7 +20,7 @@
from lightly.utils.version_compare import version_compare


def cpu_count():
def cpu_count() -> int | None:
"""Returns the number of CPUs which are present in the system.

This number is not equivalent to the number of available CPUs to the process.
@@ -26,21 +29,26 @@ def cpu_count():
return os.cpu_count()
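
The `int | None` return annotation mirrors `os.cpu_count()`, which returns None when the CPU count cannot be determined. A minimal sketch of a caller that guards against this (the `num_workers` name is illustrative):

from lightly.cli._helpers import cpu_count

# cpu_count() can return None, so fall back to a single worker.
num_workers = cpu_count() or 1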


def fix_input_path(path):
def fix_input_path(path: str) -> str:
"""Fix broken relative paths."""
if not os.path.isabs(path):
path = utils.to_absolute_path(path)
return path


def fix_hydra_arguments(config_path: str = "config", config_name: str = "config"):
def fix_hydra_arguments(
config_path: str = "config", config_name: str = "config"
) -> dict[str, str | None]:
"""Helper to make hydra arugments adaptive to installed hydra version

Hydra introduced the `version_base` argument in version 1.2.0.
We use this helper to provide backwards compatibility with older hydra versions.
"""

hydra_args = {"config_path": config_path, "config_name": config_name}
hydra_args: dict[str, str | None] = {
"config_path": config_path,
"config_name": config_name,
}

try:
if version_compare(hydra.__version__, "1.2.0") >= 0:
@@ -53,13 +61,13 @@ def fix_hydra_arguments(config_path: str = "config", config_name: str = "config"
return hydra_args
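
The `str | None` value type anticipates the `version_base` entry: on hydra >= 1.2.0 the helper presumably adds `version_base=None` to the dict (the branch is hidden in the fold above), while older versions only ever see `config_path` and `config_name`. A sketch of how such a helper is typically unpacked into `hydra.main` (the `my_cli` entry point is hypothetical):

import hydra
from omegaconf import DictConfig

from lightly.cli._helpers import fix_hydra_arguments

@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config"))
def my_cli(cfg: DictConfig) -> None:  # hypothetical entry point
    print(cfg)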


def is_url(checkpoint):
def is_url(checkpoint: str) -> bool:
"""Check whether the checkpoint is a url or not."""
is_url = "https://storage.googleapis.com" in checkpoint
return is_url


def get_ptmodel_from_config(model):
def get_ptmodel_from_config(model: dict[str, Any]) -> tuple[str, str]:
"""Get a pre-trained model from the lightly model zoo."""
key = model["name"]
key += "/simclr"
@@ -72,10 +80,14 @@ def get_ptmodel_from_config(model):
return "", key


def load_state_dict_from_url(url, map_location=None):
def load_state_dict_from_url(
url: str, map_location: torch.device | None = None
) -> dict[str, torch.Tensor | None]:
"""Try to load the checkopint from the given url."""
try:
state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location)
state_dict: dict[str, torch.Tensor] = torch.hub.load_state_dict_from_url(
url, map_location=map_location
)
return state_dict
except Exception:
print("Not able to load state dict from %s" % (url))
@@ -89,10 +101,15 @@ def load_state_dict_from_url(url, map_location=None):

# in this case downloading the pre-trained model was not possible
# notify the user and return

return {"state_dict": None}


def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits):
def _maybe_expand_batchnorm_weights(
model_dict: dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor],
num_splits: int,
) -> dict[str, torch.Tensor]:
"""Expands the weights of the BatchNorm2d to the size of SplitBatchNorm."""
running_mean = "running_mean"
running_var = "running_var"
@@ -116,7 +133,9 @@ def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits):
return state_dict
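
For context on what the expansion does: SplitBatchNorm keeps running statistics per split, so a `(C,)` buffer from a plain BatchNorm2d checkpoint has to be tiled to `(C * num_splits,)`. A standalone sketch of that idea, assuming tiling by repetition rather than mirroring the exact code hidden in the fold above:

import torch

num_splits = 8
running_mean = torch.zeros(64)  # a (C,) buffer from a BatchNorm2d checkpoint
expanded = running_mean.repeat(num_splits)  # tiled to (C * num_splits,)
assert expanded.shape == (64 * num_splits,)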


def _filter_state_dict(state_dict, remove_model_prefix_offset: int = 1):
def _filter_state_dict(
state_dict: dict[str, torch.Tensor], remove_model_prefix_offset: int = 1
) -> dict[str, torch.Tensor]:
"""Makes the state_dict compatible with the model.

Prevents unexpected key errors when loading PyTorch-Lightning checkpoints.
@@ -141,7 +160,9 @@ def _filter_state_dict(state_dict, remove_model_prefix_offset: int = 1):
return new_state_dict
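
The `remove_model_prefix_offset` parameter counts how many leading key components to drop; PyTorch-Lightning checkpoints typically prefix every key with `model.`. A minimal sketch of the stripping idea (the example key is illustrative):

key = "model.backbone.conv1.weight"  # typical PyTorch-Lightning key
remove_model_prefix_offset = 1
stripped = ".".join(key.split(".")[remove_model_prefix_offset:])
assert stripped == "backbone.conv1.weight"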


def _fix_projection_head_keys(state_dict):
def _fix_projection_head_keys(
state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""Makes the state_dict compatible with the refactored projection heads.

TODO: Remove once the models are refactored and the old checkpoints were
@@ -173,12 +194,12 @@ def _fix_projection_head_keys(state_dict):


def load_from_state_dict(
model,
state_dict,
model: nn.Module,
state_dict: dict[str, torch.Tensor],
strict: bool = True,
apply_filter: bool = True,
num_splits: int = 0,
):
) -> None:
"""Loads the model weights from the state dictionary."""

# step 1: filter state dict
@@ -196,7 +217,9 @@ def load_from_state_dict(
model.load_state_dict(state_dict, strict=strict)


def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbedding:
def get_model_from_config(
cfg: dict[str, Any], is_cli_call: bool = False
) -> SelfSupervisedEmbedding:
checkpoint = cfg["checkpoint"]
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -233,5 +256,5 @@ def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbed
if state_dict is not None:
load_from_state_dict(model, state_dict)

encoder = SelfSupervisedEmbedding(model, None, None, None)
encoder = SelfSupervisedEmbedding(model, None, None, None) # type: ignore
return encoder
1 change: 0 additions & 1 deletion pyproject.toml
@@ -191,7 +191,6 @@ exclude = '''(?x)(
lightly/cli/config/get_config.py |
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/ntx_ent_loss.py |
lightly/loss/vicreg_loss.py |
lightly/loss/tico_loss.py |
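
With `lightly/cli/_helpers.py` removed from the type-checker's `exclude` block above, mypy now picks the module up on a normal run, so the annotations in this diff are actually enforced (for example via `mypy lightly/cli/_helpers.py`).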