Create Fixed Requirements File for FLamby, Update Dynamic Weight Exchanger and FedOpt Example #68
Changes from 6 commits: 4ea6d1b, 01a3ed3, 6191978, 89d6c7b, a926aed, f210506, fb0aa84
@@ -1,176 +1,93 @@
 import argparse
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple

 import flwr as fl
 import torch
 import torch.nn as nn
 from flwr.common.logger import log
-from flwr.common.typing import Config, Metrics, NDArrays, Scalar
+from flwr.common.typing import Config
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 from examples.fedopt_example.client_data import LabelEncoder, Vocabulary, construct_dataloaders
-from examples.fedopt_example.metrics import ClientMetrics
+from examples.fedopt_example.metrics import CustomMetricMeter, MetricMeter
 from examples.models.lstm_model import LSTM
-from fl4health.clients.numpy_fl_client import NumpyFlClient
-from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
-def train(
-    model: nn.Module,
-    train_loader: DataLoader,
-    epochs: int,
-    label_encoder: LabelEncoder,
-    weight_matrix: torch.Tensor,
-    device: torch.device = torch.device("cpu"),
-) -> Metrics:
-    """Train the network on the training set."""
-    model.train()
-    criterion = torch.nn.CrossEntropyLoss(weight=weight_matrix)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.001)
-    for epoch in range(epochs):
-        running_loss = 0.0
-        n_batches = len(train_loader)
-
-        assert train_loader.batch_size is not None
-        assert isinstance(model, LSTM)
-        h0, c0 = model.init_hidden(train_loader.batch_size)
-        h0 = h0.to(device)
-        c0 = c0.to(device)
-
-        epoch_metrics = ClientMetrics(label_encoder)
-
-        for batch_index, (data, labels) in enumerate(train_loader):
-            data, labels = data.to(device), labels.to(device)
-            optimizer.zero_grad()
-            out = model(data, (h0, c0))
-            loss = criterion(out, labels)
-            loss.backward()
-            optimizer.step()
-
-            running_loss += loss.item()
-            _, predicted = torch.max(out.data, 1)
-
-            # Report some batch loss statistics every so often to track decrease
-            if batch_index % 100 == 0:
-                log(INFO, f"Batch Index {batch_index} of {n_batches}, Batch loss: {loss.item()}")
-            epoch_metrics.update_performance(predicted, labels)
-
-        log_str = epoch_metrics.summarize()
-        # Local client logging of epoch results.
-        log(
-            INFO,
-            f"Epoch: {epoch}, Client Training Loss: {running_loss/n_batches}\nClient Training Metrics:{log_str}",
-        )
-    return epoch_metrics.results
-
-
-def validate(
-    model: nn.Module,
-    validation_loader: DataLoader,
-    label_encoder: LabelEncoder,
-    device: torch.device = torch.device("cpu"),
-) -> Tuple[float, Metrics]:
-    """Validate the network on the entire validation set."""
-    model.eval()
-    criterion = torch.nn.CrossEntropyLoss()
-    loss = 0.0
-
-    assert validation_loader.batch_size is not None
-    assert isinstance(model, LSTM)
-    h0, c0 = model.init_hidden(validation_loader.batch_size)
-    h0 = h0.to(device)
-    c0 = c0.to(device)
-
-    epoch_metrics = ClientMetrics(label_encoder)
-
-    model.eval()
-    with torch.no_grad():
-        n_batches = len(validation_loader)
-        for data, labels in validation_loader:
-            data, labels = data.to(device), labels.to(device)
-            out = model(data, (h0, c0))
-            loss += criterion(out, labels).item()
-            _, predicted = torch.max(out.data, 1)
-            epoch_metrics.update_performance(predicted, labels)
-
-    log_str = epoch_metrics.summarize()
-    # Local client logging.
-    log(
-        INFO,
-        f"Client Validation Loss: {loss/n_batches}\nClient Validation Metrics:{log_str}",
-    )
-    return loss / n_batches, epoch_metrics.results
-
-
-class NewsClassifier(NumpyFlClient):
-    def __init__(self, data_path: Path, device: torch.device) -> None:
-        super().__init__(data_path, device)
-        self.parameter_exchanger = FullParameterExchanger()
-
-    def setup_client(self, config: Config) -> None:
-        super().setup_client(config)
+from fl4health.checkpointing.checkpointer import TorchCheckpointer
+from fl4health.clients.basic_client import BasicClient
+from fl4health.utils.losses import LossMeter, LossMeterType
+from fl4health.utils.metrics import MetricMeterManager
+
+
+class NewsClassifierClient(BasicClient):
+    def __init__(
+        self,
+        data_path: Path,
+        device: torch.device,
+        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
+        checkpointer: Optional[TorchCheckpointer] = None,
+    ) -> None:
+        super(BasicClient, self).__init__(data_path, device)
Review comment: Note: this is a slight hack that likely won't be necessary with additional refactors to the metrics manager classes. Since I'm not explicitly defining a set of metrics for my MetricMeter class, I am skipping over the BasicClient constructor.

Review comment (reply): Just for future reference: for #69 I slightly refactored how we handle accumulating and computing metrics, so I decided to base off this branch and update it so that we call the BasicClient constructor, saving us from having to redefine a lot of attributes.
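For readers unfamiliar with the trick being discussed: passing an explicit class to super() starts the method-resolution-order lookup after that class, so super(BasicClient, self).__init__(data_path, device) runs the constructor of BasicClient's parent while bypassing BasicClient's own. A minimal, self-contained sketch of the mechanism; the three classes below are illustrative stand-ins, not fl4health classes:

# Hypothetical classes demonstrating how super(Middle, self) skips Middle's
# own constructor and dispatches to the next class in the MRO.
class Base:
    def __init__(self) -> None:
        print("Base.__init__ runs")


class Middle(Base):
    def __init__(self) -> None:
        print("Middle.__init__ runs")  # bypassed by the call in Child
        super().__init__()


class Child(Middle):
    def __init__(self) -> None:
        # MRO is Child -> Middle -> Base; starting the lookup after Middle
        # lands on Base.__init__, mirroring super(BasicClient, self) above.
        super(Middle, self).__init__()


Child()  # prints only "Base.__init__ runs"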
+
+        self.checkpointer = checkpointer
+        self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
+        self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
+
+        self.model: nn.Module
+        self.optimizer: torch.optim.Optimizer
+
+        self.train_loader: DataLoader
+        self.val_loader: DataLoader
+        self.num_train_samples: int
+        self.num_val_samples: int
+        self.learning_rate: float
+        self.weight_matrix: torch.Tensor
+        self.vocabulary: Vocabulary
+        self.label_encoder: LabelEncoder
+        self.batch_size: int
+
+        # Need to track total_steps across rounds for WANDB reporting
+        self.total_steps: int = 0
+
+    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
         sequence_length = self.narrow_config_type(config, "sequence_length", int)
-        batch_size = self.narrow_config_type(config, "batch_size", int)
-        vocab_dimension = self.narrow_config_type(config, "vocab_dimension", int)
-        hidden_size = self.narrow_config_type(config, "hidden_size", int)
-        vocabulary = Vocabulary.from_json(self.narrow_config_type(config, "vocabulary", str))
-        label_encoder = LabelEncoder.from_json(self.narrow_config_type(config, "label_encoder", str))
+        self.batch_size = self.narrow_config_type(config, "batch_size", int)

-        train_loader, validation_loader, num_examples, weight_matrix = construct_dataloaders(
-            self.data_path, vocabulary, label_encoder, sequence_length, batch_size
+        train_loader, validation_loader, _, weight_matrix = construct_dataloaders(
+            self.data_path, self.vocabulary, self.label_encoder, sequence_length, self.batch_size
Review comment: Maybe add a comment here for future reference, for people looking at the code?

Review comment (reply): Good call.
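Context for the thread above: the underscore discards the third value returned by construct_dataloaders (the old num_examples dict), while weight_matrix is kept because it parameterizes the class-weighted criterion built in get_criterion further down. A runnable, self-contained sketch of that weighting; the tensor values are made up, not derived from the AG News data:

import torch

# Hypothetical per-class weights, standing in for the weight_matrix that
# construct_dataloaders derives from the training data.
weight_matrix = torch.tensor([0.7, 1.3, 1.0, 1.0])
criterion = torch.nn.CrossEntropyLoss(weight=weight_matrix)

logits = torch.randn(8, 4)          # a batch of 8 examples over 4 classes
labels = torch.randint(0, 4, (8,))  # random targets for illustration
loss = criterion(logits, labels)    # up-weighted classes contribute more
print(loss.item())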
         )

-        self.train_loader = train_loader
-        self.validation_loader = validation_loader
-        self.num_examples = num_examples
-        self.label_encoder = label_encoder
         self.weight_matrix = weight_matrix

-        # Model requires vocabularly and server settings, should only be setup once
-        self.setup_model(vocabulary.vocabulary_size, vocab_dimension, hidden_size)
+        return train_loader, validation_loader

-    def setup_model(self, vocab_size: int, vocab_dimension: int, hidden_size: int) -> None:
-        self.model = LSTM(vocab_size, vocab_dimension, hidden_size)
+    def get_criterion(self, config: Config) -> _Loss:
+        return torch.nn.CrossEntropyLoss(weight=self.weight_matrix)

-    def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
-        if not self.initialized:
-            self.setup_client(config)
-        # Once the model is created model weights are initialized from server
-        self.set_parameters(parameters, config)
+    def get_optimizer(self, config: Config) -> Optimizer:
+        return torch.optim.AdamW(self.model.parameters(), lr=0.01, weight_decay=0.001)

-        local_epochs = self.narrow_config_type(config, "local_epochs", int)
+    def get_model(self, config: Config) -> nn.Module:
+        vocab_dimension = self.narrow_config_type(config, "vocab_dimension", int)
+        hidden_size = self.narrow_config_type(config, "hidden_size", int)
+        return LSTM(self.vocabulary.vocabulary_size, vocab_dimension, hidden_size)

-        fit_metrics = train(
-            self.model,
-            self.train_loader,
-            local_epochs,
-            self.label_encoder,
-            self.weight_matrix,
-            device=self.device,
-        )
-        # Result should contain local parameters, number of examples on client, and a dictionary holding metrics
-        # calculation results.
-        return (
-            self.get_parameters(config),
-            self.num_examples["train_set"],
-            fit_metrics,
-        )
+    def setup_client(self, config: Config) -> None:
+        self.vocabulary = Vocabulary.from_json(self.narrow_config_type(config, "vocabulary", str))
+        self.label_encoder = LabelEncoder.from_json(self.narrow_config_type(config, "label_encoder", str))
+        # Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val
+        train_key_to_meter_map: Dict[str, MetricMeter] = {"prediction": CustomMetricMeter(self.label_encoder)}
+        self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map)
+        val_key_to_meter_map: Dict[str, MetricMeter] = {"prediction": CustomMetricMeter(self.label_encoder)}
+        self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map)
+        super().setup_client(config)

-    def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]:
-        if not self.initialized:
-            self.setup_client(config)
-
-        self.set_parameters(parameters, config)
-        loss, evaluate_metrics = validate(self.model, self.validation_loader, self.label_encoder, self.device)
-        # Result should return the loss, number of examples on client, and a dictionary holding metrics
-        # calculation results.
-        return (
-            loss,
-            self.num_examples["validation_set"],
-            evaluate_metrics,
-        )
+    def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
+        # While this isn't optimal, this is a good example of a custom predict function to manipulate the predictions
+        assert isinstance(self.model, LSTM)
+        h0, c0 = self.model.init_hidden(self.batch_size)
+        h0 = h0.to(self.device)
+        c0 = c0.to(self.device)
+        preds = self.model(input, (h0, c0))
+        return {"prediction": preds}


 if __name__ == "__main__":

@@ -181,5 +98,5 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
     # Load model and data
     DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     data_path = Path(args.dataset_path)
-    client = NewsClassifier(data_path, DEVICE)
+    client = NewsClassifierClient(data_path, DEVICE)
     fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
Review comment: The changes in this file simply relate to moving from the previous news classification dataset to the AG News dataset.
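Taken together, the diff replaces the hand-rolled train/validate/fit/evaluate plumbing with BasicClient hook overrides. A condensed sketch of the resulting pattern, using only the hook names that appear in this diff; the bodies are illustrative, and the fl4health training loop that invokes these hooks is assumed, not shown:

# Sketch of the override pattern this PR adopts (imports as in the diff above).
# Bodies are placeholders; only the method names/signatures come from the diff.
class SketchClient(BasicClient):
    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        # Build and return (train_loader, validation_loader); the base class
        # is assumed to drive the training loop and sample counting.
        ...

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.AdamW(self.model.parameters(), lr=0.01)

    def get_model(self, config: Config) -> nn.Module:
        ...  # construct the model from server-provided config values

    def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Keys returned here must match the meter map built in setup_client
        # (e.g. "prediction" -> CustomMetricMeter), so metrics attach correctly.
        return {"prediction": self.model(input)}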