Skip to content

Commit

Permalink
chore: 🚨 Fix local mypy run
Browse files Browse the repository at this point in the history
  • Loading branch information
adjavon committed Dec 9, 2024
1 parent 70a1b54 commit 6e314aa
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 46 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
ignore_missing_imports = True
allow_redefinition = True
42 changes: 21 additions & 21 deletions src/quac/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def check_requirements(
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
) -> Callable:
if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
Expand Down Expand Up @@ -66,8 +66,8 @@ def make_counterfactual_dataset(
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class)
) -> List[Tuple[str, int, int]]:
"""Generates a list of samples of a form (path_to_sample, source_class, target_class)
for data organized in a counterfactual style directory.
The dataset is organized in the following way:
Expand Down Expand Up @@ -147,7 +147,7 @@ def make_paired_dataset(
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, str, int, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
"""Generates a list of samples of a form (path_to_sample, target_path, class_index, target_class_index).
See :class:`DatasetFolder` for details.
Expand Down Expand Up @@ -217,8 +217,8 @@ def make_paired_attribution_dataset(
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, str, int, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
) -> List[Tuple[str, str, str, int, int]]:
"""Generates a list of samples of a form (path_to_sample, path_to_cf, path_to_attr, source_class, target_class).
See :class:`DatasetFolder` for details.
Expand Down Expand Up @@ -302,8 +302,8 @@ def make_paired_attribution_dataset(
class Sample:
image: torch.Tensor
source_class_index: int
path: Path = None
source_class: str = None
path: Optional[Path] = None
source_class: Optional[str] = None


# TODO remove?
Expand All @@ -312,10 +312,10 @@ class CounterfactualSample:
counterfactual: torch.Tensor
target_class_index: int
source_class_index: int
path: Path = None
counterfactual_path: Path = None
source_class: str = None
target_class: str = None
path: Optional[Path] = None
counterfactual_path: Optional[Path] = None
source_class: Optional[str] = None
target_class: Optional[str] = None


@dataclass
Expand All @@ -324,10 +324,10 @@ class PairedSample:
counterfactual: torch.Tensor
source_class_index: int
target_class_index: int
path: Path = None
counterfactual_path: Path = None
source_class: str = None
target_class: str = None
path: Optional[Path] = None
counterfactual_path: Optional[Path] = None
source_class: Optional[str] = None
target_class: Optional[str] = None


@dataclass
Expand All @@ -337,11 +337,11 @@ class SampleWithAttribution:
counterfactual: torch.Tensor
source_class_index: int
target_class_index: int
path: Path = None
counterfactual_path: Path = None
source_class: str = None
target_class: str = None
attribution_path: Path = None
path: Optional[Path] = None
counterfactual_path: Optional[Path] = None
source_class: Optional[str] = None
target_class: Optional[str] = None
attribution_path: Optional[Path] = None


class PairedImageDataset(Dataset):
Expand Down
30 changes: 15 additions & 15 deletions src/quac/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Utilities for generating counterfactual images."""

from .model import LatentInferenceModel, ReferenceInferenceModel
from .model import LatentInferenceModel, ReferenceInferenceModel, InferenceModel
from .data import LabelFreePngFolder

import logging
from quac.training.classification import ClassifierWrapper
import torch
from torchvision import transforms
from typing import Union
from typing import Union, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,7 +85,7 @@ def load_stargan(
kind="latent",
single_output_encoder: bool = False,
final_activation: Union[str, None] = None,
) -> torch.nn.Module:
) -> InferenceModel:
"""
Load an inference version of the StarGANv2 model from a checkpoint.
Expand All @@ -104,7 +104,7 @@ def load_stargan(
the loaded inference model
"""
if kind == "reference":
latent_inference_model = ReferenceInferenceModel(
latent_inference_model: InferenceModel = ReferenceInferenceModel(
checkpoint_dir=latent_model_checkpoint_dir,
img_size=img_size,
input_dim=input_dim,
Expand All @@ -115,7 +115,7 @@ def load_stargan(
final_activation=final_activation,
)
else:
latent_inference_model = LatentInferenceModel(
latent_inference_model: InferenceModel = LatentInferenceModel( # type: ignore[no-redef]
checkpoint_dir=latent_model_checkpoint_dir,
img_size=img_size,
input_dim=input_dim,
Expand All @@ -124,13 +124,13 @@ def load_stargan(
num_domains=num_domains,
final_activation=final_activation,
)
latent_inference_model.load_checkpoint(checkpoint_iter)
latent_inference_model.load_checkpoint(checkpoint_iter) # type: ignore
latent_inference_model.eval()
return latent_inference_model


@torch.no_grad()
def get_counterfactual(
def get_counterfactual( # type: ignore
classifier,
latent_inference_model,
x,
Expand All @@ -140,13 +140,13 @@ def get_counterfactual(
batch_size=10,
device=None,
max_tries=100,
best_pred_so_far=None,
best_cf_so_far=None,
best_cf_path_so_far=None,
best_pred_so_far=Optional[torch.Tensor],
best_cf_so_far=Optional[torch.Tensor],
best_cf_path_so_far=Optional[str],
error_if_not_found=False,
return_path=False,
return_pred=False,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[Union[str, torch.Tensor]], Optional[torch.Tensor]]:
"""
Tries to find a counterfactual for the given sample, given the target.
It creates a batch, and returns one of the samples if it is classified correctly.
Expand Down Expand Up @@ -185,13 +185,13 @@ def get_counterfactual(
f"Not enough reference images, reducing max_tries to {max_tries}."
)
# Get a batch of reference images, starting from batch_size * max_tries, of size batch_size
ref_batch, ref_paths = zip(
ref_batch_tuples, ref_paths = zip(
*[
dataset_ref[i]
for i in range(batch_size * (max_tries - 1), batch_size * max_tries)
]
)
ref_batch = torch.stack(ref_batch)
ref_batch = torch.stack(ref_batch_tuples)
# Generate batch_size counterfactuals
xcf = latent_inference_model(
x_multiple.to(device),
Expand Down Expand Up @@ -254,7 +254,7 @@ def get_counterfactual(
if return_path and kind == "reference":
if return_pred:
return best_cf_so_far, best_cf_path_so_far, best_pred_so_far
return best_cf_so_far, best_cf_path_so_far
return best_cf_so_far, best_cf_path_so_far # type: ignore
if return_pred:
return best_cf_so_far, best_pred_so_far
return best_cf_so_far, best_pred_so_far # type: ignore
return best_cf_so_far
19 changes: 16 additions & 3 deletions src/quac/generate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,22 @@
)
from quac.training.checkpoint import CheckpointIO
import torch
from typing import Optional


class LatentInferenceModel(torch.nn.Module):
class InferenceModel(torch.nn.Module):
"""A superclass for inference models.
Useful for type-checking.
"""

# TODO add checkpoint loading to this class
def __init__(self) -> None:
super().__init__()
pass


class LatentInferenceModel(InferenceModel):
def __init__(
self,
checkpoint_dir,
Expand Down Expand Up @@ -56,7 +69,7 @@ def forward(self, x_src, y_trg):
return x_fake


class ReferenceInferenceModel(torch.nn.Module):
class ReferenceInferenceModel(InferenceModel):
def __init__(
self,
checkpoint_dir,
Expand All @@ -73,7 +86,7 @@ def __init__(
img_size, style_dim, input_dim=input_dim, final_activation=final_activation
)
if single_output_encoder:
style_encoder = SingleOutputStyleEncoder(
style_encoder: StyleEncoder = SingleOutputStyleEncoder(
img_size, style_dim, num_domains, input_dim=input_dim
)
else:
Expand Down
12 changes: 8 additions & 4 deletions src/quac/training/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,28 @@

try:
import wandb

wandb_available = True
except ImportError:
wandb = None
wandb_available = False

try:
from torch.utils.tensorboard import SummaryWriter

tensorboard_available = True
except ImportError:
SummaryWriter = None
tensorboard_available = False


class Logger:
def create(log_type, resume_iter=0, hparams={}, **kwargs):
if log_type == "wandb":
if wandb is None:
if not wandb_available:
raise ImportError("wandb is not installed.")
resume = "allow" if resume_iter > 0 else False
return WandBLogger(hparams=hparams, resume=resume, **kwargs)
elif log_type == "tensorboard":
if SummaryWriter is None:
if not tensorboard_available:
raise ImportError("Tensorboard is not available.")
purge_step = resume_iter if resume_iter > 0 else None
return TensorboardLogger(hparams=hparams, purge_step=purge_step, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/quac/training/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self.optims[net] = torch.optim.Adam(
params=self.nets[net].parameters(),
lr=f_lr if net == "mapping_network" else lr,
betas=[beta1, beta2],
betas=(beta1, beta2),
weight_decay=weight_decay,
)

Expand Down Expand Up @@ -125,7 +125,7 @@ def latent_dim(self):
latent_dim = self.nets.mapping_network.module.latent_dim
return latent_dim

def train(
def train( # type: ignore
self,
loader,
resume_iter: int = 0,
Expand Down
2 changes: 1 addition & 1 deletion src/quac/training/stargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def forward(self, x, y):
return s


class SingleOutputStyleEncoder(nn.Module):
class SingleOutputStyleEncoder(StyleEncoder, nn.Module):
def __init__(
self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512, input_dim=3
):
Expand Down

0 comments on commit 6e314aa

Please sign in to comment.