-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve ClientTrainer with Seed Support (#12)
* fix: remove torchvision from dependencies * feat: add FilteredDataset to utils * fix: use ModelSelector consistently * fix: initialize num_parallels and share_dir in base class * feat: seed support
- Loading branch information
1 parent
ce5ea33
commit 2f181c2
Showing
7 changed files
with
193 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,11 @@ | ||
from blazefl.utils.dataset import FilteredDataset | ||
from blazefl.utils.seed import RandomState, seed_everything | ||
from blazefl.utils.serialize import deserialize_model, serialize_model | ||
|
||
__all__ = ["serialize_model", "deserialize_model"] | ||
__all__ = [ | ||
"serialize_model", | ||
"deserialize_model", | ||
"FilteredDataset", | ||
"seed_everything", | ||
"RandomState", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from collections.abc import Callable | ||
|
||
from torch.utils.data import Dataset | ||
|
||
|
||
class FilteredDataset(Dataset): | ||
def __init__( | ||
self, | ||
indices: list[int], | ||
original_data: list, | ||
original_targets: list | None = None, | ||
transform: Callable | None = None, | ||
target_transform: Callable | None = None, | ||
) -> None: | ||
self.data = [original_data[i] for i in indices] | ||
if original_targets is not None: | ||
assert len(original_data) == len(original_targets) | ||
self.targets = [original_targets[i] for i in indices] | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
|
||
def __len__(self) -> int: | ||
return len(self.data) | ||
|
||
def __getitem__(self, index: int) -> tuple: | ||
img = self.data[index] | ||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
if hasattr(self, "targets"): | ||
target = self.targets[index] | ||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
return img, target | ||
|
||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import random | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
def seed_everything(seed: int, device: str) -> None: | ||
random.seed(seed) | ||
os.environ["PYTHONHASHSEED"] = str(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
if device.startswith("cuda"): | ||
torch.cuda.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
|
||
|
||
@dataclass | ||
class CUDARandomState: | ||
manual_seed: int | ||
cudnn_deterministic: bool | ||
cudnn_benchmark: bool | ||
|
||
|
||
@dataclass | ||
class RandomState: | ||
random: tuple | ||
environ: str | ||
numpy: dict | ||
torch: int | ||
cuda: CUDARandomState | None | ||
|
||
@classmethod | ||
def get_random_state(cls, device: str) -> "RandomState": | ||
if device.startswith("cuda"): | ||
return cls( | ||
random.getstate(), | ||
os.environ["PYTHONHASHSEED"], | ||
np.random.get_state(), | ||
torch.initial_seed(), | ||
CUDARandomState( | ||
torch.cuda.initial_seed(), | ||
torch.backends.cudnn.deterministic, | ||
torch.backends.cudnn.benchmark, | ||
), | ||
) | ||
return cls( | ||
random.getstate(), | ||
os.environ["PYTHONHASHSEED"], | ||
np.random.get_state(), | ||
torch.initial_seed(), | ||
None, | ||
) | ||
|
||
@staticmethod | ||
def set_random_state(random_state: "RandomState") -> None: | ||
random.setstate(random_state.random) | ||
os.environ["PYTHONHASHSEED"] = random_state.environ | ||
np.random.set_state(random_state.numpy) | ||
if random_state.cuda is not None: | ||
torch.manual_seed(random_state.torch) | ||
torch.cuda.manual_seed(random_state.cuda.manual_seed) | ||
torch.backends.cudnn.deterministic = random_state.cuda.cudnn_deterministic | ||
torch.backends.cudnn.benchmark = random_state.cuda.cudnn_benchmark |
Oops, something went wrong.