Metric refactor #69
Merged · 12 commits · Nov 21, 2023
Changes from 9 commits
2 changes: 1 addition & 1 deletion examples/apfl_example/config.yaml
@@ -1,5 +1,5 @@
# Parameters that describe server
n_server_rounds: 25 # The number of rounds to run FL
n_server_rounds: 5 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
3 changes: 1 addition & 2 deletions examples/federated_eval_example/client.py
@@ -13,7 +13,7 @@
from fl4health.clients.evaluate_client import EvaluateClient
from fl4health.utils.load_data import load_cifar10_test_data
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric, MetricMeterType
from fl4health.utils.metrics import Accuracy, Metric


class CifarClient(EvaluateClient):
@@ -26,7 +26,6 @@ def __init__(
device=device,
model_checkpoint_path=model_checkpoint_path,
loss_meter_type=LossMeterType.AVERAGE,
metric_meter_type=MetricMeterType.AVERAGE,
)

def initialize_global_model(self, config: Config) -> Optional[nn.Module]:
43 changes: 15 additions & 28 deletions examples/fedopt_example/client.py
@@ -1,6 +1,6 @@
import argparse
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple

import flwr as fl
import torch
@@ -11,43 +11,29 @@
from torch.utils.data import DataLoader

from examples.fedopt_example.client_data import LabelEncoder, Vocabulary, construct_dataloaders
from examples.fedopt_example.metrics import CustomMetricMeter, MetricMeter
from examples.fedopt_example.metrics import CompoundMetric
from examples.models.lstm_model import LSTM
from fl4health.checkpointing.checkpointer import TorchCheckpointer
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.losses import LossMeter, LossMeterType
from fl4health.utils.metrics import MetricMeterManager
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Metric


class NewsClassifierClient(BasicClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super(BasicClient, self).__init__(data_path, device)
self.checkpointer = checkpointer
self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)

self.model: nn.Module
self.optimizer: torch.optim.Optimizer

self.train_loader: DataLoader
self.val_loader: DataLoader
self.num_train_samples: int
self.num_val_samples: int
self.learning_rate: float
super().__init__(data_path, metrics, device, loss_meter_type, checkpointer)
self.weight_matrix: torch.Tensor
self.vocabulary: Vocabulary
self.label_encoder: LabelEncoder
self.batch_size: int

# Need to track total_steps across rounds for WANDB reporting
self.total_steps: int = 0

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
sequence_length = self.narrow_config_type(config, "sequence_length", int)
self.batch_size = self.narrow_config_type(config, "batch_size", int)
@@ -74,21 +60,22 @@ def get_model(self, config: Config) -> nn.Module:
def setup_client(self, config: Config) -> None:
self.vocabulary = Vocabulary.from_json(self.narrow_config_type(config, "vocabulary", str))
self.label_encoder = LabelEncoder.from_json(self.narrow_config_type(config, "label_encoder", str))
# Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val
train_key_to_meter_map: Dict[str, MetricMeter] = {"prediction": CustomMetricMeter(self.label_encoder)}
self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map)
val_key_to_meter_map: Dict[str, MetricMeter] = {"prediction": CustomMetricMeter(self.label_encoder)}
self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map)
# Since the label_encoder is required by CompoundMetric but is not available until after we receive
# it from the server, we pass it to the CompoundMetric through the CompoundMetric._setup method once it's
# available
for metric in self.metrics:
if isinstance(metric, CompoundMetric):
metric._setup(self.label_encoder)
super().setup_client(config)

def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
# While this isn't optimal, this is a good example of a custom predict function to manipulate the predictions
assert isinstance(self.model, LSTM)
h0, c0 = self.model.init_hidden(self.batch_size)
h0 = h0.to(self.device)
c0 = c0.to(self.device)
preds = self.model(input, (h0, c0))
return {"prediction": preds}
return {"prediction": preds}, {}


if __name__ == "__main__":
@@ -99,5 +86,5 @@ def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = NewsClassifierClient(data_path, DEVICE)
client = NewsClassifierClient(data_path, [CompoundMetric("")], DEVICE)
Collaborator comment: Should we give this metric a name?

fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
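
Two details from this file's diff are worth spelling out: predict now returns a pair of dictionaries (predictions and intermediate features) rather than a single dictionary, and the CompoundMetric is constructed with an empty name, which the reviewer questions above. A hedged sketch of both points follows; the batch shape and metric name are illustrative, and the call assumes a client that has already run setup_client:

```python
import torch

# Conceptually, continuing the script above: predict() now returns
# (predictions, features), each keyed by name; this client exposes no features.
batch = torch.randint(0, 100, (32, 64))  # illustrative batch of token indices
preds, features = client.predict(batch)
assert "prediction" in preds and features == {}

# Addressing the review question, the metric could carry a descriptive name
# (the name below is hypothetical):
# client = NewsClassifierClient(data_path, [CompoundMetric("ag_news_compound")], DEVICE)
```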
24 changes: 15 additions & 9 deletions examples/fedopt_example/metrics.py
@@ -1,14 +1,14 @@
import json
from logging import INFO
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from flwr.common.logger import log
from flwr.common.typing import Metrics
from sklearn.metrics import confusion_matrix

from examples.fedopt_example.client_data import LabelEncoder
from fl4health.utils.metrics import MetricMeter
from fl4health.utils.metrics import Metric


class Outcome:
@@ -75,20 +75,26 @@ def compute_metrics(self) -> Metrics:
return metrics


class CustomMetricMeter(MetricMeter):
def __init__(self, label_encoder: LabelEncoder) -> None:
class CompoundMetric(Metric):
def __init__(self, name: str) -> None:
"""
This class is used to compute metrics associated with the AG's News task. There are a number of classes and
we want to accumulate a bunch of statistics all at once to facilitate the computation of a number of different
metrics for this problem. As such, we define our own MetricMeter and bypass the standard metric meter
implementations, which calculate separate metrics individually.
metrics for this problem. As such, we define our own Metric class and bypass the standard SimpleMetric class,
which calculates separate metrics individually.

Args:
label_encoder (LabelEncoder): This class is used to determine the mapping of integers to label names for
the AG's news task.

Collaborator comment: Maybe transfer this comment about label_encoder to the setup method below?

name (str): The name of the compound metric.
"""
super().__init__(name)
self.true_preds = 0
self.total_preds = 0
self.classes: List[str]
self.label_to_class: Dict[int, str]
self.n_classes: int
self.outcome_dict: Dict[str, Outcome]

def _setup(self, label_encoder: LabelEncoder) -> None:
Collaborator comment: Typically the _ prefix is reserved for protected or private methods. That is, methods that are exclusively called within the class itself, rather than externally. All that is to say, I would recommend dropping the _ based on the way this is being used 🙂

Collaborator (author) reply: Yeah, good call! I initially was aiming for it to be internal but realized we have to call it externally, so I forgot to change it back.
self.classes = label_encoder.classes
self.outcome_dict = self._initialize_outcomes(self.classes)
self.label_to_class = label_encoder.label_to_class
@@ -117,7 +123,7 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
self.outcome_dict[true_class].false_negative += count
self.outcome_dict[pred_class].false_positive += count

def compute(self) -> Metrics:
def compute(self, name: Optional[str]) -> Metrics:
sum_f1 = 0.0
results: Metrics = {"total_preds": self.total_preds, "true_preds": self.true_preds}
log_string = ""
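
For context on the new interface: a compound metric now subclasses Metric directly, accumulates raw counts in update, derives several statistics at once in compute, and receives late-arriving dependencies through a setup hook. Below is a minimal sketch, not the PR's implementation: it assumes the base Metric at this commit only requires update and compute, substitutes a plain dictionary for LabelEncoder, and names the hook setup per the review discussion above.

```python
from typing import Dict, Optional

import torch
from flwr.common.typing import Metrics

from fl4health.utils.metrics import Metric


class TinyCompoundMetric(Metric):
    """Illustrative compound metric: accumulates raw counts across batches and
    derives several statistics at once in compute()."""

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.true_preds = 0
        self.total_preds = 0
        self.label_to_class: Dict[int, str]

    def setup(self, label_to_class: Dict[int, str]) -> None:
        # Late initialization, mirroring CompoundMetric._setup: the label
        # mapping only becomes available once the server sends it.
        self.label_to_class = label_to_class

    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        # input holds per-class scores; count correct and total predictions.
        preds = torch.argmax(input, dim=1)
        self.true_preds += int((preds == target).sum().item())
        self.total_preds += int(target.numel())

    def compute(self, name: Optional[str]) -> Metrics:
        # Derive multiple statistics from the accumulated counts at once.
        prefix = f"{name} - " if name else ""
        return {
            f"{prefix}total_preds": self.total_preds,
            f"{prefix}true_preds": self.true_preds,
            f"{prefix}accuracy": self.true_preds / max(self.total_preds, 1),
        }
```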
4 changes: 1 addition & 3 deletions examples/partial_weight_exchange_example/client.py
@@ -17,7 +17,7 @@
from examples.partial_weight_exchange_example.client_data import construct_dataloaders
from fl4health.clients.dynamic_weight_exchange_client import DynamicWeightExchangeClient
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric, MetricMeterType
from fl4health.utils.metrics import Accuracy, Metric


class TransformerPartialExchangeClient(DynamicWeightExchangeClient):
@@ -27,14 +27,12 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
)
self.test_loader: DataLoader

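
Since the constructor above now takes only data_path, metrics, device, and loss_meter_type, wiring up metrics reduces to passing Metric instances. A minimal sketch based on that signature (the dataset path is hypothetical):

```python
from pathlib import Path

import torch

from examples.partial_weight_exchange_example.client import TransformerPartialExchangeClient
from fl4health.utils.metrics import Accuracy

# Minimal sketch using the post-refactor signature shown above. Averaging is
# handled inside the Metric objects, so no MetricMeterType is passed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = TransformerPartialExchangeClient(
    data_path=Path("examples/datasets/news_data"),  # hypothetical path
    metrics=[Accuracy()],
    device=device,
)
```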
57 changes: 23 additions & 34 deletions fl4health/clients/apfl_client.py
@@ -4,14 +4,13 @@
import torch
from flwr.common.typing import Config
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import TorchCheckpointer
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.losses import Losses, LossMeter, LossMeterType
from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType
from fl4health.utils.losses import Losses, LossMeterType
from fl4health.utils.metrics import Metric


class ApflClient(BasicClient):
@@ -21,42 +20,15 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super(BasicClient, self).__init__(data_path, device)
self.metrics = metrics
self.checkpointer = checkpointer
self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)

# Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val
train_key_to_meter_map = {
"local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_local"),
"global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_global"),
"personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_personal"),
}
self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map)
val_key_to_meter_map = {
"local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_local"),
"global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_global"),
"personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_personal"),
}
self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map)

self.train_loader: DataLoader
self.val_loader: DataLoader
self.num_train_samples: int
self.num_val_samples: int
super().__init__(data_path, metrics, device, loss_meter_type, checkpointer)

self.model: ApflModule
self.learning_rate: float
self.optimizer: torch.optim.Optimizer
self.local_optimizer: torch.optim.Optimizer

# Need to track total_steps across rounds for WANDB reporting
self.total_steps: int = 0

def is_start_of_local_training(self, step: int) -> bool:
return step == 0

@@ -90,9 +62,9 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses,

# Personal predictions are generated as a convex combination of the output
# of local and global models
preds = self.predict(input)
preds, features = self.predict(input)
# Parameters of local model are updated to minimize loss of personalized model
losses = self.compute_loss(preds, target)
losses = self.compute_loss(preds, features, target)
losses.backward.backward()
self.local_optimizer.step()

@@ -102,7 +74,24 @@
def get_parameter_exchanger(self, config: Config) -> FixedLayerExchanger:
return FixedLayerExchanger(self.model.layers_to_exchange())

def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses:
def compute_loss(
self,
preds: Union[torch.Tensor, Dict[str, torch.Tensor]],
features: Dict[str, torch.Tensor],
target: torch.Tensor,
) -> Losses:
"""
Computes loss given predictions of the model and ground truth data.

Args:
preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name.
features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name.
target (torch.Tensor): Ground truth data to evaluate predictions against.

Returns:
Losses: Object containing checkpoint loss, backward loss and additional losses indexed by name.
Additional losses include global and local losses.
"""
assert isinstance(preds, dict)
personal_loss = self.criterion(preds["personal"], target)
global_loss = self.criterion(preds["global"], target)
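
Pulling the APFL changes together: train_step unpacks (preds, features) from predict, compute_loss consumes both dictionaries, and the returned Losses object carries the backward loss that drives the local update. A hedged sketch of that flow; the helper function is illustrative, and the convex-combination comment restates the APFL formulation referenced in the diff rather than code from this PR:

```python
import torch

def apfl_flow_sketch(client, input: torch.Tensor, target: torch.Tensor) -> None:
    # Personal predictions are a convex combination of local and global model
    # outputs: personal = alpha * local + (1 - alpha) * global. predict()
    # returns them alongside a (here empty) features dictionary.
    preds, features = client.predict(input)

    # compute_loss returns a Losses object holding the checkpoint loss, the
    # backward loss, and additional (global/local) losses indexed by name.
    losses = client.compute_loss(preds, features, target)
    losses.backward.backward()  # the backward loss drives gradient computation
    client.local_optimizer.step()  # local model minimizes the personalized loss
```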