-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adopt Disk-Based Data Sharing for Reliable Inter-Process Communication (
#7) * fix: restrict data sharing to disk * fix: warn when data is already shared * fix: implement local_process in ParallelClientTrainer * chore: update version
- Loading branch information
1 parent
97e72b2
commit 0332082
Showing
6 changed files
with
104 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,13 @@ | ||
[project] | ||
name = "blazefl" | ||
version = "0.1.0b2" | ||
version = "0.1.0b3" | ||
description = "A blazing-fast and lightweight simulation framework for Federated Learning." | ||
readme = "README.md" | ||
authors = [ | ||
{ name = "kitsuya0828", email = "[email protected]" } | ||
] | ||
requires-python = ">=3.12" | ||
dependencies = [ | ||
"pydantic>=2.10.3", | ||
"torch>=2.5.1", | ||
"torchvision>=0.20.1", | ||
"tqdm>=4.67.1", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,58 @@ | ||
import multiprocessing as mp | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from multiprocessing.pool import ApplyResult | ||
from pathlib import Path | ||
from typing import Generic, TypeVar | ||
|
||
import torch | ||
from tqdm import tqdm | ||
|
||
class SerialClientTrainer(ABC): | ||
UplinkPackage = TypeVar("UplinkPackage") | ||
DownlinkPackage = TypeVar("DownlinkPackage") | ||
|
||
|
||
class SerialClientTrainer(ABC, Generic[UplinkPackage, DownlinkPackage]): | ||
@abstractmethod | ||
def uplink_package(self) -> list[Any]: ... | ||
def uplink_package(self) -> list[UplinkPackage]: ... | ||
|
||
@abstractmethod | ||
def local_process(self, payload: Any, cid_list: list[int]) -> None: ... | ||
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ... | ||
|
||
|
||
DiskSharedData = TypeVar("DiskSharedData") | ||
|
||
|
||
class ParallelClientTrainer( | ||
SerialClientTrainer[UplinkPackage, DownlinkPackage], | ||
Generic[UplinkPackage, DownlinkPackage, DiskSharedData], | ||
): | ||
def __init__(self, num_parallels: int, tmp_dir: Path): | ||
self.num_parallels = num_parallels | ||
self.tmp_dir = tmp_dir | ||
self.cache: list[UplinkPackage] = [] | ||
|
||
class ParallelClientTrainer(ABC): | ||
@abstractmethod | ||
def uplink_package(self) -> list[Any]: ... | ||
def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData: ... | ||
|
||
@staticmethod | ||
@abstractmethod | ||
def local_process(self, payload: Any, cid_list: list[int]) -> None: ... | ||
def process_client(path: Path) -> Path: ... | ||
|
||
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: | ||
pool = mp.Pool(processes=self.num_parallels) | ||
jobs: list[ApplyResult] = [] | ||
|
||
for cid in cid_list: | ||
path = self.tmp_dir.joinpath(f"{cid}.pkl") | ||
data = self.get_shared_data(cid, payload) | ||
torch.save(data, path) | ||
jobs.append(pool.apply_async(self.process_client, (path,))) | ||
|
||
for job in tqdm(jobs, desc="Client", leave=False): | ||
path = job.get() | ||
assert isinstance(path, Path) | ||
package = torch.load(path, weights_only=False) | ||
self.cache.append(package) | ||
|
||
pool.close() | ||
pool.join() |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.