Skip to content

Commit

Permalink
Improve ClientTrainer with Seed Support (#12)
Browse files Browse the repository at this point in the history
* fix: remove torchvision from dependencies

* feat: add FilteredDataset to utils

* fix: use ModelSelector consistently

* fix: initialize num_parallels and share_dir in base class

* feat: seed support
  • Loading branch information
kitsuya0828 authored Dec 11, 2024
1 parent ce5ea33 commit 2f181c2
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 116 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "blazefl"
version = "0.1.0b4"
version = "0.1.0b5"
description = "A blazing-fast and lightweight simulation framework for Federated Learning."
readme = "README.md"
authors = [
Expand All @@ -16,8 +16,8 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"numpy>=2.2.0",
"torch>=2.5.1",
"torchvision>=0.20.1",
"tqdm>=4.67.1",
"types-tqdm>=4.67.0.20241119",
]
Expand Down
57 changes: 38 additions & 19 deletions src/blazefl/contrib/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
from dataclasses import dataclass
from logging import Logger
from pathlib import Path

import torch
Expand All @@ -14,7 +13,12 @@
SerialClientTrainer,
ServerHandler,
)
from blazefl.utils.serialize import deserialize_model, serialize_model
from blazefl.utils import (
RandomState,
deserialize_model,
seed_everything,
serialize_model,
)


@dataclass
Expand All @@ -31,21 +35,20 @@ class FedAvgDownlinkPackage:
class FedAvgServerHandler(ServerHandler):
def __init__(
self,
model: torch.nn.Module,
model_selector: ModelSelector,
model_name: str,
dataset: PartitionedDataset,
global_round: int,
num_clients: int,
sample_ratio: float,
device: str,
logger: Logger,
) -> None:
self.model = model
self.model = model_selector.select_model(model_name)
self.dataset = dataset
self.global_round = global_round
self.num_clients = num_clients
self.sample_ratio = sample_ratio
self.device = device
self.logger = logger

self.client_buffer_cache: list[FedAvgUplinkPackage] = []
self.num_clients_per_round = int(self.num_clients * self.sample_ratio)
Expand Down Expand Up @@ -95,18 +98,21 @@ def downlink_package(self) -> FedAvgDownlinkPackage:
return FedAvgDownlinkPackage(model_parameters)


class FedAvgSerialClientTrainer(SerialClientTrainer):
class FedAvgSerialClientTrainer(
SerialClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
):
def __init__(
self,
model: torch.nn.Module,
model_selector: ModelSelector,
model_name: str,
dataset: PartitionedDataset,
device: str,
num_clients: int,
epochs: int,
batch_size: int,
lr: float,
) -> None:
self.model = model
self.model = model_selector.select_model(model_name)
self.dataset = dataset
self.device = device
self.num_clients = num_clients
Expand All @@ -130,7 +136,9 @@ def local_process(
pack = self.train(model_parameters, data_loader)
self.cache.append(pack)

def train(self, model_parameters, train_loader) -> FedAvgUplinkPackage:
def train(
self, model_parameters: torch.Tensor, train_loader: DataLoader
) -> FedAvgUplinkPackage:
deserialize_model(self.model, model_parameters)
self.model.train()

Expand Down Expand Up @@ -167,7 +175,9 @@ class FedAvgDiskSharedData:
lr: float
device: str
cid: int
seed: int
payload: FedAvgDownlinkPackage
state_path: Path


class FedAvgParalleClientTrainer(
Expand All @@ -179,35 +189,44 @@ def __init__(
self,
model_selector: ModelSelector,
model_name: str,
tmp_dir: Path,
share_dir: Path,
state_dir: Path,
dataset: PartitionedDataset,
device: str,
num_clients: int,
epochs: int,
batch_size: int,
lr: float,
seed: int,
num_parallels: int,
) -> None:
super().__init__(num_parallels, share_dir)
self.model_selector = model_selector
self.model_name = model_name
self.state_dir = state_dir
self.dataset = dataset
self.epochs = epochs
self.batch_size = batch_size
self.lr = lr
self.tmp_dir = tmp_dir
self.tmp_dir.mkdir(parents=True, exist_ok=True)
self.device = device
self.num_clients = num_clients
self.num_parallels = num_parallels
self.seed = seed

self.cache: list[FedAvgUplinkPackage] = []
if self.device == "cuda":
self.device_count = torch.cuda.device_count()

@staticmethod
def process_client(path: Path) -> Path:
data = torch.load(path, weights_only=False)
assert isinstance(data, FedAvgDiskSharedData)

if data.state_path.exists():
state = torch.load(data.state_path)
assert isinstance(state, RandomState)
RandomState.set_random_state(state)
else:
seed_everything(data.seed, device=data.device)

model = data.model_selector.select_model(data.model_name)
train_loader = data.dataset.get_dataloader(
type_="train",
Expand All @@ -222,10 +241,8 @@ def process_client(path: Path) -> Path:
epochs=data.epochs,
lr=data.lr,
)
torch.save(
package,
path,
)
torch.save(package, path)
torch.save(RandomState.get_random_state(device=data.device), data.state_path)
return path

@staticmethod
Expand Down Expand Up @@ -278,7 +295,9 @@ def get_shared_data(
lr=self.lr,
device=device,
cid=cid,
seed=self.seed,
payload=payload,
state_path=self.state_dir,
)
return data

Expand Down
7 changes: 4 additions & 3 deletions src/blazefl/core/client_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ class ParallelClientTrainer(
SerialClientTrainer[UplinkPackage, DownlinkPackage],
Generic[UplinkPackage, DownlinkPackage, DiskSharedData],
):
def __init__(self, num_parallels: int, tmp_dir: Path):
def __init__(self, num_parallels: int, share_dir: Path) -> None:
self.num_parallels = num_parallels
self.tmp_dir = tmp_dir
self.share_dir = share_dir
self.share_dir.mkdir(parents=True, exist_ok=True)
self.cache: list[UplinkPackage] = []

@abstractmethod
Expand All @@ -43,7 +44,7 @@ def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
jobs: list[ApplyResult] = []

for cid in cid_list:
path = self.tmp_dir.joinpath(f"{cid}.pkl")
path = self.share_dir.joinpath(f"{cid}.pkl")
data = self.get_shared_data(cid, payload)
torch.save(data, path)
jobs.append(pool.apply_async(self.process_client, (path,)))
Expand Down
10 changes: 9 additions & 1 deletion src/blazefl/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from blazefl.utils.dataset import FilteredDataset
from blazefl.utils.seed import RandomState, seed_everything
from blazefl.utils.serialize import deserialize_model, serialize_model

__all__ = ["serialize_model", "deserialize_model"]
__all__ = [
"serialize_model",
"deserialize_model",
"FilteredDataset",
"seed_everything",
"RandomState",
]
36 changes: 36 additions & 0 deletions src/blazefl/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import Callable

from torch.utils.data import Dataset


class FilteredDataset(Dataset):
def __init__(
self,
indices: list[int],
original_data: list,
original_targets: list | None = None,
transform: Callable | None = None,
target_transform: Callable | None = None,
) -> None:
self.data = [original_data[i] for i in indices]
if original_targets is not None:
assert len(original_data) == len(original_targets)
self.targets = [original_targets[i] for i in indices]
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self.data)

def __getitem__(self, index: int) -> tuple:
img = self.data[index]
if self.transform is not None:
img = self.transform(img)

if hasattr(self, "targets"):
target = self.targets[index]
if self.target_transform is not None:
target = self.target_transform(target)
return img, target

return img
67 changes: 67 additions & 0 deletions src/blazefl/utils/seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import random
from dataclasses import dataclass

import numpy as np
import torch


def seed_everything(seed: int, device: str) -> None:
    """Seed every RNG used by the framework for a reproducible run.

    Seeds ``PYTHONHASHSEED``, Python's ``random``, NumPy, and PyTorch.
    When ``device`` names a CUDA device, additionally seeds all CUDA
    generators and forces cuDNN into deterministic (non-benchmark) mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not device.startswith("cuda"):
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@dataclass
class CUDARandomState:
manual_seed: int
cudnn_deterministic: bool
cudnn_benchmark: bool


@dataclass
class RandomState:
random: tuple
environ: str
numpy: dict
torch: int
cuda: CUDARandomState | None

@classmethod
def get_random_state(cls, device: str) -> "RandomState":
if device.startswith("cuda"):
return cls(
random.getstate(),
os.environ["PYTHONHASHSEED"],
np.random.get_state(),
torch.initial_seed(),
CUDARandomState(
torch.cuda.initial_seed(),
torch.backends.cudnn.deterministic,
torch.backends.cudnn.benchmark,
),
)
return cls(
random.getstate(),
os.environ["PYTHONHASHSEED"],
np.random.get_state(),
torch.initial_seed(),
None,
)

@staticmethod
def set_random_state(random_state: "RandomState") -> None:
random.setstate(random_state.random)
os.environ["PYTHONHASHSEED"] = random_state.environ
np.random.set_state(random_state.numpy)
if random_state.cuda is not None:
torch.manual_seed(random_state.torch)
torch.cuda.manual_seed(random_state.cuda.manual_seed)
torch.backends.cudnn.deterministic = random_state.cuda.cudnn_deterministic
torch.backends.cudnn.benchmark = random_state.cuda.cudnn_benchmark
Loading

0 comments on commit 2f181c2

Please sign in to comment.