Skip to content

Commit

Permalink
Adopt Disk-Based Data Sharing for Reliable Inter-Process Communication (#7)
Browse files Browse the repository at this point in the history

* fix: restrict data sharing to disk

* fix: warn when data is already shared

* fix: implement local_process in ParallelClientTrainer

* chore: update version
  • Loading branch information
kitsuya0828 authored Dec 9, 2024
1 parent 97e72b2 commit 0332082
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 172 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
[project]
name = "blazefl"
version = "0.1.0b2"
version = "0.1.0b3"
description = "A blazing-fast and lightweight simulation framework for Federated Learning."
readme = "README.md"
authors = [
{ name = "kitsuya0828", email = "[email protected]" }
]
requires-python = ">=3.12"
dependencies = [
"pydantic>=2.10.3",
"torch>=2.5.1",
"torchvision>=0.20.1",
"tqdm>=4.67.1",
Expand Down
110 changes: 55 additions & 55 deletions src/blazefl/contrib/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from pathlib import Path

import torch
import torch.multiprocessing as mp
from pydantic import BaseModel, ConfigDict
from torch.utils.data import DataLoader
from tqdm import tqdm

from blazefl.core import (
Expand All @@ -14,7 +13,6 @@
PartitionedDataset,
SerialClientTrainer,
ServerHandler,
SharedData,
)
from blazefl.utils.serialize import deserialize_model, serialize_model

Expand Down Expand Up @@ -159,8 +157,8 @@ def uplink_package(self) -> list[FedAvgUplinkPackage]:
return self.cache


class FedAvgSharedMemoryData(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass
class FedAvgDiskSharedData:
model_selector: ModelSelector
model_name: str
dataset: PartitionedDataset
Expand All @@ -169,14 +167,14 @@ class FedAvgSharedMemoryData(BaseModel):
lr: float
device: str
cid: int
payload: FedAvgDownlinkPackage


class FedAvgDiskData(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model_parameters: torch.Tensor


class FedAvgParalleClientTrainer(ParallelClientTrainer):
class FedAvgParalleClientTrainer(
ParallelClientTrainer[
FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgDiskSharedData
]
):
def __init__(
self,
model_selector: ModelSelector,
Expand Down Expand Up @@ -207,31 +205,49 @@ def __init__(
self.device_count = torch.cuda.device_count()

@staticmethod
def process_client(
shared_data: SharedData[FedAvgSharedMemoryData, FedAvgDiskData],
) -> FedAvgUplinkPackage:
shared_memory_data = shared_data.get_shared_memory_data()
disk_data = shared_data.get_disk_data()
model = shared_memory_data.model_selector.select_model(
shared_memory_data.model_name
def process_client(path: Path) -> Path:
    """Run one client's local training from a disk-shared data file.

    Executed inside a worker process. Loads a `FedAvgDiskSharedData`
    instance from `path`, trains the selected model on that client's
    partition, then overwrites the same file with the resulting
    `FedAvgUplinkPackage` so the parent process can read it back.

    Args:
        path: Location of the serialized `FedAvgDiskSharedData` for one client.

    Returns:
        The same `path`, now containing the trained uplink package.
    """
    # weights_only=False is required: the file holds an arbitrary dataclass,
    # not just tensors. NOTE(review): this deserializes via pickle — safe only
    # because the file was written by our own parent process.
    data = torch.load(path, weights_only=False)
    assert isinstance(data, FedAvgDiskSharedData)
    model = data.model_selector.select_model(data.model_name)
    train_loader = data.dataset.get_dataloader(
        type_="train",
        cid=data.cid,
        batch_size=data.batch_size,
    )
    package = FedAvgParalleClientTrainer.train(
        model=model,
        model_parameters=data.payload.model_parameters,
        train_loader=train_loader,
        device=data.device,
        epochs=data.epochs,
        lr=data.lr,
    )
    # Reuse the input file as the result channel back to the parent.
    torch.save(
        package,
        path,
    )
    return path

deserialize_model(model, disk_data.model_parameters)
model.to(shared_memory_data.device)
@staticmethod
def train(
model: torch.nn.Module,
model_parameters: torch.Tensor,
train_loader: DataLoader,
device: str,
epochs: int,
lr: float,
) -> FedAvgUplinkPackage:
model.to(device)
deserialize_model(model, model_parameters)
model.train()
train_loader = shared_memory_data.dataset.get_dataloader(
type_="train",
cid=shared_memory_data.cid,
batch_size=shared_memory_data.batch_size,
)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=shared_memory_data.lr)

data_size = 0
for _ in range(shared_memory_data.epochs):
for _ in range(epochs):
for data, target in train_loader:
data = data.to(shared_memory_data.device)
target = target.to(shared_memory_data.device)
data = data.to(device)
target = target.to(device)

output = model(data)
loss = criterion(output, target)
Expand All @@ -243,44 +259,28 @@ def process_client(
optimizer.step()

model_parameters = serialize_model(model)

return FedAvgUplinkPackage(model_parameters, data_size)

def get_shared_data(
self, cid: int, payload: FedAvgDownlinkPackage
) -> SharedData[FedAvgSharedMemoryData, FedAvgDiskData]:
shared_memory_data = FedAvgSharedMemoryData(
) -> FedAvgDiskSharedData:
if self.device == "cuda":
device = f"cuda:{cid % self.device_count}"
else:
device = self.device
data = FedAvgDiskSharedData(
model_selector=self.model_selector,
model_name=self.model_name,
dataset=self.dataset,
epochs=self.epochs,
batch_size=self.batch_size,
lr=self.lr,
device=f"cuda:{cid % self.device_count}"
if self.device == "cuda"
else self.device,
device=device,
cid=cid,
payload=payload,
)
disk_data = FedAvgDiskData(model_parameters=payload.model_parameters)
shared_data = SharedData(
shared_memory_data=shared_memory_data,
disk_data=disk_data,
disk_path=self.tmp_dir.joinpath(f"{cid}.pt"),
)
return shared_data

def local_process(
self, payload: FedAvgDownlinkPackage, cid_list: list[int]
) -> None:
pool = mp.Pool(processes=self.num_parallels)
jobs = []
for cid in cid_list:
client_shared_data = self.get_shared_data(cid, payload).share()
jobs.append(pool.apply_async(self.process_client, (client_shared_data,)))

for job in tqdm(jobs, desc="Client", leave=False):
result = job.get()
assert isinstance(result, FedAvgUplinkPackage)
self.cache.append(result)
return data

def uplink_package(self) -> list[FedAvgUplinkPackage]:
    """Return the uplink packages accumulated by `local_process`."""
    return self.cache
2 changes: 0 additions & 2 deletions src/blazefl/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from blazefl.core.model_selector import ModelSelector
from blazefl.core.partitioned_dataset import PartitionedDataset
from blazefl.core.server_handler import ServerHandler
from blazefl.core.shared_data import SharedData

__all__ = [
"SerialClientTrainer",
"ParallelClientTrainer",
"ModelSelector",
"PartitionedDataset",
"ServerHandler",
"SharedData",
]
54 changes: 47 additions & 7 deletions src/blazefl/core/client_trainer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,58 @@
import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import Any
from multiprocessing.pool import ApplyResult
from pathlib import Path
from typing import Generic, TypeVar

import torch
from tqdm import tqdm

class SerialClientTrainer(ABC):
UplinkPackage = TypeVar("UplinkPackage")
DownlinkPackage = TypeVar("DownlinkPackage")


class SerialClientTrainer(ABC, Generic[UplinkPackage, DownlinkPackage]):
@abstractmethod
def uplink_package(self) -> list[Any]: ...
def uplink_package(self) -> list[UplinkPackage]: ...

@abstractmethod
def local_process(self, payload: Any, cid_list: list[int]) -> None: ...
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...


DiskSharedData = TypeVar("DiskSharedData")


class ParallelClientTrainer(
SerialClientTrainer[UplinkPackage, DownlinkPackage],
Generic[UplinkPackage, DownlinkPackage, DiskSharedData],
):
def __init__(self, num_parallels: int, tmp_dir: Path):
self.num_parallels = num_parallels
self.tmp_dir = tmp_dir
self.cache: list[UplinkPackage] = []

class ParallelClientTrainer(ABC):
@abstractmethod
def uplink_package(self) -> list[Any]: ...
def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData: ...

@staticmethod
@abstractmethod
def local_process(self, payload: Any, cid_list: list[int]) -> None: ...
def process_client(path: Path) -> Path: ...

def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
    """Train the given clients in parallel worker processes.

    For each client id, the downlink payload is materialized on disk
    (one file per client under `self.tmp_dir`) and a worker is handed
    only the file path. Workers overwrite the file with their uplink
    package, which is read back and appended to `self.cache`.

    Args:
        payload: Server downlink package distributed to every client.
        cid_list: Client ids to process in this round.
    """
    # `with` guarantees the pool is torn down even if a worker raises;
    # the original close()/join() pair leaked the pool on error. By the
    # time the block exits, every result has already been collected, so
    # terminating the pool here is safe.
    with mp.Pool(processes=self.num_parallels) as pool:
        jobs: list[ApplyResult] = []

        for cid in cid_list:
            path = self.tmp_dir.joinpath(f"{cid}.pkl")
            # Workers receive only the path — the data itself travels
            # through the filesystem, not through pickled IPC arguments.
            data = self.get_shared_data(cid, payload)
            torch.save(data, path)
            jobs.append(pool.apply_async(self.process_client, (path,)))

        for job in tqdm(jobs, desc="Client", leave=False):
            path = job.get()
            assert isinstance(path, Path)
            # weights_only=False: the file holds a package object written
            # by our own worker, not just raw tensors.
            package = torch.load(path, weights_only=False)
            self.cache.append(package)
41 changes: 0 additions & 41 deletions src/blazefl/core/shared_data.py

This file was deleted.

Loading

0 comments on commit 0332082

Please sign in to comment.