Skip to content

Commit

Permalink
Merge pull request #264 from VectorInstitute/dbe/mypy_follow_imports_on
Browse files Browse the repository at this point in the history
Turning follow imports on in mypy for better coverage
  • Loading branch information
emersodb authored Oct 26, 2024
2 parents 53431b7 + 0770f1b commit f366fa3
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 12 deletions.
11 changes: 6 additions & 5 deletions examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fl4health.model_bases.autoencoders_base import VariationalAe
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.server.base_server import FlServerWithCheckpointing
from fl4health.strategies.fedprox import FedProx
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -50,8 +50,9 @@ def main(config: Dict[str, Any]) -> None:
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedProx(
# Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the
# FedProx proximal weight mu
strategy = FedAvgWithAdaptiveConstraint(
min_fit_clients=config["n_clients"],
min_evaluate_clients=config["n_clients"],
# Server waits for min_available_clients before starting FL rounds
Expand All @@ -62,8 +63,8 @@ def main(config: Dict[str, Any]) -> None:
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_all_model_parameters(model),
adaptive_proximal_weight=config["adaptive_proximal_weight"],
proximal_weight=config["proximal_weight"],
adapt_loss_weight=config["adapt_proximal_weight"],
initial_loss_weight=config["initial_proximal_weight"],
)

server = FlServerWithCheckpointing(
Expand Down
5 changes: 3 additions & 2 deletions examples/fedopt_example/client_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

nltk.download("punkt")
nltk.download("punkt_tab")


class LabelEncoder:
Expand All @@ -22,7 +22,8 @@ def __init__(self, classes: List[str], label_to_class: Dict[int, str], class_to_
@staticmethod
def encoder_from_dataframe(df: pd.DataFrame, class_column: str) -> "LabelEncoder":
categories = df[class_column].astype("category")
label_to_class = dict(set(zip(categories.cat.codes, categories)))
categories_str = [str(category) for category in categories.to_list()]
label_to_class = dict(set(zip(categories.cat.codes, categories_str)))
class_to_label = {category: label for label, category in label_to_class.items()}
classes = categories.unique().tolist()
return LabelEncoder(classes, label_to_class, class_to_label)
Expand Down
2 changes: 1 addition & 1 deletion fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
return iterable
else:
# Create a clean looking tqdm instance that matches the flwr logging
kwargs = {
kwargs: Any = {
"leave": True,
"ascii": " >=",
# "desc": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} ",
Expand Down
6 changes: 4 additions & 2 deletions fl4health/feature_alignment/string_columns_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

Expand All @@ -13,7 +15,7 @@ class TextMulticolumnTransformer(BaseEstimator, TransformerMixin):
def __init__(self, transformer: TextFeatureTransformer):
self.transformer = transformer

def fit(self, X: pd.DataFrame, y: pd.DataFrame = None) -> "TextMulticolumnTransformer":
def fit(self, X: pd.DataFrame, y: Optional[pd.DataFrame] = None) -> "TextMulticolumnTransformer":
joined_X = X.apply(lambda x: " ".join(x), axis=1)
self.transformer.fit(joined_X)
return self
Expand All @@ -32,7 +34,7 @@ class TextColumnTransformer(BaseEstimator, TransformerMixin):
def __init__(self, transformer: TextFeatureTransformer):
self.transformer = transformer

def fit(self, X: pd.DataFrame, y: pd.DataFrame = None) -> "TextColumnTransformer":
def fit(self, X: pd.DataFrame, y: Optional[pd.DataFrame] = None) -> "TextColumnTransformer":
assert isinstance(X, pd.DataFrame) and X.shape[1] == 1
self.transformer.fit(X[X.columns[0]])
return self
Expand Down
1 change: 0 additions & 1 deletion fl4health/server/model_merge_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float
Tuple[History, float]: The first element of the tuple is a History object containing the aggregated
metrics returned from the clients. Tuple also contains elapsed time in seconds for round.
"""

self.reports_manager.report({"host_type": "server", "fit_start": datetime.datetime.now()})

history = History()
Expand Down
44 changes: 43 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
[mypy]
mypy_path=.
follow_imports = normal
ignore_missing_imports = True
ignore_missing_imports = False
install_types = True
pretty = True
non_interactive = True
disallow_untyped_defs = True
no_implicit_optional = True
check_untyped_defs = True
exclude = (venv|research/gemini/*)

[mypy-dp_accounting.*]
ignore_missing_imports = True

[mypy-sklearn.*]
ignore_missing_imports = True

[mypy-cyclops.*]
ignore_missing_imports = True

[mypy-batchgenerators.*]
ignore_missing_imports = True

[mypy-nnunetv2.*]
ignore_missing_imports = True

[mypy-flamby.*]
ignore_missing_imports = True

[mypy-opacus.*]
ignore_missing_imports = True

[mypy-qpth.*]
ignore_missing_imports = True

[mypy-nltk.*]
ignore_missing_imports = True

[mypy-datasets.*]
ignore_missing_imports = True

[mypy-transformers.*]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True

[mypy-torchvision.*]
ignore_missing_imports = True

[mypy-efficientnet_pytorch.*]
ignore_missing_imports = True

0 comments on commit f366fa3

Please sign in to comment.