From 254803572b17c59f5a2b909496a9b5a05a5b3c38 Mon Sep 17 00:00:00 2001 From: wnma3mz Date: Thu, 3 Feb 2022 21:15:26 +0800 Subject: [PATCH] update MOON (update df and ccvr) and (add cli for ccvr and df) --- example/MOON_reproduction/MyClients.py | 4 - example/MOON_reproduction/MyStrategys.py | 179 ++++++++--------- example/MOON_reproduction/MyTrainers.py | 244 ++++++----------------- example/MOON_reproduction/main.py | 155 ++++++-------- 4 files changed, 217 insertions(+), 365 deletions(-) diff --git a/example/MOON_reproduction/MyClients.py b/example/MOON_reproduction/MyClients.py index 4a39a05..85ee0e5 100644 --- a/example/MOON_reproduction/MyClients.py +++ b/example/MOON_reproduction/MyClients.py @@ -1,5 +1,4 @@ # coding: utf-8 - import copy from flearn.client import Client @@ -91,10 +90,7 @@ def revice(self, i, glob_params): self.scheduler.step() # self.trainer.model.load_state_dict(self.w_local_bak) self.trainer.model.load_state_dict(update_w) - self.trainer.server_model = copy.deepcopy(self.trainer.model) - self.trainer.server_model.load_state_dict(update_w) - self.trainer.server_model.eval() self.trainer.server_state_dict = copy.deepcopy(update_w) return { "code": 200, diff --git a/example/MOON_reproduction/MyStrategys.py b/example/MOON_reproduction/MyStrategys.py index a6183aa..991f27c 100644 --- a/example/MOON_reproduction/MyStrategys.py +++ b/example/MOON_reproduction/MyStrategys.py @@ -6,7 +6,6 @@ import torch.nn as nn import torch.optim as optim -from flearn.common import Encrypt from flearn.common.distiller import Distiller, KDLoss from flearn.common.strategy import AVG @@ -24,13 +23,14 @@ def client_revice(self, trainer, data_glob_d): return w_local, logits_glob def server(self, ensemble_params_lst, round_): - g_shared = super(Distill, self).server(ensemble_params_lst, round_) + ensemble_params = super(Distill, self).server(ensemble_params_lst, round_) logits_lst = self.extract_lst(ensemble_params_lst, "logits") - g_shared["logits_glob"] = self.aggregate_logits(logits_lst) - return g_shared + ensemble_params["logits_glob"] = self.aggregate_logits(logits_lst) + return ensemble_params - def aggregate_logits(self, logits_lst): + @staticmethod + def aggregate_logits(logits_lst): user_logits = 0 for item in logits_lst: user_logits += item @@ -56,12 +56,42 @@ def dyn_f(self, w_glob, w_local_lst): for k in self.h.keys(): w_glob[k] = w_glob[k] - self.alpha * self.h[k] self.theta = w_glob + return w_glob + + def server_post_processing(self, ensemble_params_lst, ensemble_params): + w_local_lst = self.extract_lst(ensemble_params_lst, "params") + ensemble_params["w_glob"] = self.dyn_f(ensemble_params["w_glob"], w_local_lst) + return ensemble_params + + def server(self, ensemble_params_lst, round_): + ensemble_params = super().server(ensemble_params_lst, round_) + return self.server_post_processing(ensemble_params_lst, ensemble_params) + + +class ParentStrategy(AVG): + def __init__(self, model_fpath, strategy=None, **kwargs): + self.strategy = strategy + if self.strategy != None: + key_lst = list(self.strategy.__dict__.keys()) + for k, v in kwargs.items(): + if k in key_lst: + self.__dict__[k] = v + super().__init__(model_fpath) + + def client(self, trainer, agg_weight=1.0): + if self.strategy != None: + return self.strategy.client(trainer, agg_weight) + return super().client(trainer, agg_weight) def server(self, ensemble_params_lst, round_): - g_shared = super(Dyn, self).server(ensemble_params_lst, round_) - _, w_local_lst = self.server_pre_processing(ensemble_params_lst) - 
self.dyn_f(g_shared["w_glob"], w_local_lst) - return g_shared + if self.strategy != None: + return self.strategy.server(ensemble_params_lst, round_) + return super().server(ensemble_params_lst, round_) + + def client_revice(self, trainer, data_glob_d): + if self.strategy != None: + return self.strategy.client_revice(trainer, data_glob_d) + return super().client_revice(trainer, data_glob_d) class DFDistiller(Distiller): @@ -111,7 +141,8 @@ def multi( soft_target_lst = [] for teacher in self.teacher_lst: with torch.no_grad(): - soft_target_lst.append(teacher(x)) + _, _, soft_target = teacher(x) + soft_target_lst.append(soft_target) loss = self.multi_loss(method, soft_target_lst, output) @@ -122,18 +153,13 @@ def multi( return student.state_dict() -class DF(AVG): - """ - Ensemble distillation for robust model fusion in federated learning - - [1] Lin T, Kong L, Stich S U, et al. Ensemble distillation for robust model fusion in federated learning[J]. arXiv preprint arXiv:2006.07242, 2020. - """ - - def __init__(self, model_fpath, model_base): - super().__init__(model_fpath) +class DF(ParentStrategy): + def __init__(self, model_fpath, model_base, strategy=None, **kwargs): + super().__init__(model_fpath, strategy, **kwargs) self.model_base = model_base - def ensemble_w(self, ensemble_params_lst, w_glob, **kwargs): + def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs): + w_glob = ensemble_params["w_glob"] agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst) teacher_lst = [] @@ -144,19 +170,21 @@ def ensemble_w(self, ensemble_params_lst, w_glob, **kwargs): self.model_base.load_state_dict(w_glob) student = copy.deepcopy(self.model_base) - self.distiller = DFDistiller( - kwargs.pop("kd_loader"), - kwargs.pop("device"), - kd_loss=KDLoss(kwargs.pop("T")), + kd_loader, device = kwargs.pop("kd_loader"), kwargs.pop("device") + temperature = kwargs.pop("T") + distiller = DFDistiller( + kd_loader, + device, + kd_loss=KDLoss(temperature), ) molecular = np.sum(agg_weight_lst) weight_lst = [w / molecular for w in agg_weight_lst] # agg_weight_lst:应该依照每个模型在验证集上的性能来进行分配 - w_glob = self.distiller.multi( + ensemble_params["w_glob"] = distiller.multi( teacher_lst, student, kwargs.pop("method"), weight_lst=weight_lst, **kwargs ) - return w_glob + return ensemble_params def server(self, ensemble_params_lst, round_, **kwargs): """ @@ -169,11 +197,10 @@ def server(self, ensemble_params_lst, round_, **kwargs): "kd_loader": 蒸馏数据集,仅需输入,无需标签 } """ - g_shared = super(DF, self).server(ensemble_params_lst, round_) - g_shared["w_glob"] = self.ensemble_w( - ensemble_params_lst, g_shared["w_glob"], **kwargs + ensemble_params = super().server(ensemble_params_lst, round_) + return self.server_post_processing( + ensemble_params_lst, ensemble_params, **kwargs ) - return g_shared class ReTrain: @@ -209,25 +236,10 @@ def run(self, student, lr=0.01): return student.state_dict() -class CCVR(Distill, Dyn, DF): - def __init__( - self, - model_fpath, - glob_model_base, - model_base=None, - strategy=None, - h=None, - shared_key_layers=None, - ): - # kwargs = {"model_fpath": model_fpath, "model_base": model_base, "h": h} - # super().__init__(**kwargs) - self.model_fpath = model_fpath - self.model_base = model_base - self.h = h - self.glob_model = glob_model_base - self.strategy = strategy - self.shared_key_layers = shared_key_layers - self.encrypt = Encrypt() +class CCVR(ParentStrategy): + def __init__(self, model_fpath, glob_model_base, strategy=None, **kwargs): + 
super().__init__(model_fpath, strategy, **kwargs) + self.glob_model_base = glob_model_base @staticmethod def client_mean_feat(feat_lst, label_lst): @@ -256,16 +268,7 @@ def client_mean_feat(feat_lst, label_lst): return upload_d def client(self, trainer, agg_weight=1.0): - if self.strategy == "lg": - w_shared = {"agg_weight": agg_weight} - w_local = trainer.weight - w_shared["params"] = {k: w_local[k].cpu() for k in self.shared_key_layers} - else: - w_shared = super(CCVR, self).client(trainer, agg_weight) - - if self.strategy == "distill": - w_shared["logits"] = trainer.logit_tracker.avg() - + w_shared = super().client(trainer, agg_weight) w_shared["fd"] = self.client_mean_feat(trainer.feat_lst, trainer.label_lst) return w_shared @@ -312,47 +315,41 @@ def server_mean_feat(self, fd_lst): return fd_d - def server(self, ensemble_params_lst, round_, **kwargs): - agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst) - try: - w_glob = self.server_ensemble(agg_weight_lst, w_local_lst) - except Exception as e: - return self.server_exception(e) - g_shared = {"w_glob": w_glob} - - if self.strategy == "dyn": - _, w_local_lst = self.server_pre_processing(ensemble_params_lst) - self.dyn_f(g_shared["w_glob"], w_local_lst) - elif self.strategy == "distill": - logits_lst = self.extract_lst(ensemble_params_lst, "logits") - g_shared["logits_glob"] = self.aggregate_logits(logits_lst) - elif self.strategy == "df": - g_shared["w_glob"] = self.ensemble_w( - ensemble_params_lst, g_shared["w_glob"], **kwargs - ) - + def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs): # 特征参数提取 fd_lst = self.extract_lst(ensemble_params_lst, "fd") fd_d = self.server_mean_feat(fd_lst) # 重新训练分类器 self.retrainer = ReTrain(fd_d, kwargs["device"]) - self.glob_model = self.load_model(self.glob_model, g_shared["w_glob"]) - w_train = self.retrainer.run(self.glob_model) + self.glob_model_base = self.load_model( + self.glob_model_base, ensemble_params["w_glob"] + ) + w_train = self.retrainer.run(self.glob_model_base) for k in w_train.keys(): - g_shared["w_glob"][k] = w_train[k].cpu() + ensemble_params["w_glob"][k] = w_train[k].cpu() - return g_shared + return ensemble_params - def client_revice(self, trainer, data_glob_d): - w_local = trainer.weight - w_glob = data_glob_d["w_glob"] - for k in w_glob.keys(): - w_local[k] = w_glob[k] + def server(self, ensemble_params_lst, round_, **kwargs): + ensemble_params = super().server(ensemble_params_lst, round_) + return self.server_post_processing( + ensemble_params_lst, ensemble_params, **kwargs + ) - if self.strategy == "distill": - logits_glob = data_glob_d["logits_glob"] - return w_local, logits_glob - return w_local +class DFCCVR(CCVR): + def __init__( + self, model_fpath, model_base, glob_model_base, strategy=None, **kwargs + ): + super().__init__(model_fpath, glob_model_base, strategy, **kwargs) + self.model_base = model_base + self.df = DF(model_fpath, model_base, strategy, **kwargs) + + def server(self, ensemble_params_lst, round_, **kwargs): + # 先DF后CCVR + ensemble_params = self.df.server(ensemble_params_lst, round_, **kwargs) + return self.server_post_processing( + ensemble_params_lst, ensemble_params, **kwargs + ) diff --git a/example/MOON_reproduction/MyTrainers.py b/example/MOON_reproduction/MyTrainers.py index 30c4745..2d05276 100644 --- a/example/MOON_reproduction/MyTrainers.py +++ b/example/MOON_reproduction/MyTrainers.py @@ -12,18 +12,9 @@ class AVGTrainer(Trainer): - def batch(self, data, target): + def forward(self, data, target): 
_, _, output = self.model(data) - loss = self.criterion(output, target) - - if self.model.training: - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + return output class MOONTrainer(Trainer): @@ -37,8 +28,9 @@ def __init__(self, model, optimizer, criterion, device, display=True): # CIFAR-10, CIFAR-100, and Tiny-Imagenet are 5, 1, and 1 self.mu = 5 - def moon_loss(self, data, pro1): + def fed_loss(self): if self.global_model != None: + data, pro1 = self.data, self.pro1 # 全局与本地的对比损失,越小越好 with torch.no_grad(): _, pro2, _ = self.global_model(data) @@ -59,7 +51,7 @@ def moon_loss(self, data, pro1): else: return 0 - def moon_eval_model(self): + def eval_model(self): for previous_net in self.previous_model_lst: previous_net.eval() previous_net.to(self.device) @@ -68,32 +60,20 @@ def moon_eval_model(self): self.global_model.eval() self.global_model.to(self.device) - def train(self, data_loader, epochs=1): - self.moon_eval_model() - return super(MOONTrainer, self).train(data_loader, epochs) - - def batch(self, data, target): + def forward(self, data, target): _, pro1, output = self.model(data) - loss = self.criterion(output, target) - if self.model.training: - loss += self.moon_loss(data, pro1) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + self.data, self.pro1 = data, pro1 + return output -class ProxTrainer(Trainer): +class ProxTrainer(AVGTrainer): def __init__(self, model, optimizer, criterion, device, display=True): super(ProxTrainer, self).__init__(model, optimizer, criterion, device, display) self.global_model = None # CIFAR-10, CIFAR-100, and Tiny-Imagenet are 0.01, 0.001, and 0.001 self.prox_mu = 0.01 - def prox_loss(self): + def fed_loss(self): if self.global_model != None: w_diff = torch.tensor(0.0, device=self.device) for w, w_t in zip(self.model.parameters(), self.global_model.parameters()): @@ -102,26 +82,16 @@ def prox_loss(self): else: return 0 - def batch(self, data, target): - _, _, output = self.model(data) - loss = self.criterion(output, target) - - if self.model.training: - loss += self.prox_loss() - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + def eval_model(self): + if self.global_model != None: + self.global_model.eval() + self.global_model.to(self.device) -class DynTrainer(Trainer): +class DynTrainer(AVGTrainer): def __init__(self, model, optimizer, criterion, device, display=True): super(DynTrainer, self).__init__(model, optimizer, criterion, device, display) - self.server_model = copy.deepcopy(self.model) - self.server_state_dict = self.server_model.state_dict() + self.server_state_dict = {} # save client's gradient self.prev_grads = None @@ -134,8 +104,8 @@ def __init__(self, model, optimizer, criterion, device, display=True): self.alpha = 0.01 - def dyn_loss(self): - if self.server_model != None: + def fed_loss(self): + if self.server_state_dict != {}: # Linear penalty curr_params = None for name, param in self.model.named_parameters(): @@ -149,15 +119,14 @@ def dyn_loss(self): # Quadratic Penalty, 全局模型与客户端模型尽可能小 quad_penalty = 0.0 for name, param in self.model.named_parameters(): - quad_penalty += F.mse_loss( - param, self.server_state_dict[name], reduction="sum" - ) + server_param = 
self.server_state_dict[name].to(self.device) + quad_penalty += F.mse_loss(param, server_param, reduction="sum") return -lin_penalty + self.alpha / 2.0 * quad_penalty else: return 0 - def update_prev_grads(self): + def update_info(self): # update prev_grads self.prev_grads = None for param in self.model.parameters(): @@ -167,22 +136,6 @@ def update_prev_grads(self): else: self.prev_grads = torch.cat((self.prev_grads, real_grad), dim=0) - def batch(self, data, target): - _, _, output = self.model(data) - loss = self.criterion(output, target) - - if self.model.training: - loss += self.dyn_loss() - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - self.update_prev_grads() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc - class LSDTrainer(Trainer): def __init__(self, model, optimizer, criterion, device, display=True): @@ -192,7 +145,7 @@ def __init__(self, model, optimizer, criterion, device, display=True): # self.mu_kd = 0.5 self.kd_loss = KDLoss(2) - def lsd_eval_model(self): + def eval_model(self): if self.teacher_model != None: self.teacher_model.eval() self.teacher_model.to(self.device) @@ -201,26 +154,18 @@ def train(self, data_loader, epochs=1): self.lsd_eval_model() return super(LSDTrainer, self).train(data_loader, epochs) - def lsd_loss(self, data, output): + def fed_loss(self): if self.teacher_model != None: + data, output = self.data, self.output with torch.no_grad(): t_h, _, t_output = self.teacher_model(data) - return self.mu_kd * self.kd_loss(output, t_output.detach()) return 0 - def batch(self, data, target): - h, _, output = self.model(data) - loss = self.criterion(output, target) - if self.model.training: - loss += self.lsd_loss(data, output) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + def forward(self, data, target): + _, _, output = self.model(data) + self.data, self.output = data, output + return output class LogitTracker: @@ -252,122 +197,61 @@ def avg(self): class DistillTrainer(Trainer): def __init__(self, model, optimizer, criterion, device, display=True): - super(DistillTrainer, self).__init__( - model, optimizer, criterion, device, display - ) + super().__init__(model, optimizer, criterion, device, display) self.logit_tracker = LogitTracker(10) # cifar10 self.glob_logit = None self.kd_mu = 1 self.kd_loss = KDLoss(2) - def train(self, data_loader, epochs=1): - self.model.train() - epoch_loss, epoch_accuracy = [], [] - for ep in range(1, epochs + 1): - with torch.enable_grad(): - loss, accuracy = self._iteration(data_loader) - epoch_loss.append(loss) - epoch_accuracy.append(accuracy) - - # 非上传轮,清空特征 - if ep != epochs: - self.logit_tracker.clear() - - return np.mean(epoch_loss), np.mean(epoch_accuracy) - - def distill_loss(self, output, target): + def fed_loss(self): if self.glob_logit != None: + output, target = self.output, self.target self.glob_logit = self.glob_logit.to(self.device) target_p = self.glob_logit[target, :] return self.kd_mu * self.kd_loss(output, target_p) return 0 - def batch(self, data, target): - _, _, output = self.model(data) - loss = self.criterion(output, target) - - if self.model.training: - loss += self.distill_loss(output, target) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() + def update_info(self): + # 更新上传的logits + self.logit_tracker.update(self.output, self.target) - # 更新上传的logits - 
self.logit_tracker.update(output, target) + def clear_info(self): + self.logit_tracker.clear() - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + def forward(self, data, target): + _, _, output = self.model(data) + self.output, self.target = output, target + return output -class CCVRTrainer(MOONTrainer, ProxTrainer, LSDTrainer, DistillTrainer, DynTrainer): +class CCVRTrainer(AVGTrainer): # 从左至右继承,右侧不会覆盖左侧的变量/函数 - def __init__( - self, model, optimizer, criterion, device, display=True, strategy=None - ): - super(CCVRTrainer, self).__init__(model, optimizer, criterion, device, display) - self.feat_lst = [] - self.label_lst = [] - self.fed_loss_d = { - "avg": 0, - "lg": 0, - "prox": self.prox_loss(), - "dyn": self.dyn_loss(), - } - assert strategy in ["avg", "moon", "prox", "dyn", "lsd", "distill", "lg"] - self.strategy = strategy - - def train(self, data_loader, epochs=1): - self.moon_eval_model() - self.lsd_eval_model() - - self.model.train() - epoch_loss, epoch_accuracy = [], [] - for ep in range(1, epochs + 1): - with torch.enable_grad(): - loss, accuracy = self._iteration(data_loader) - epoch_loss.append(loss) - epoch_accuracy.append(accuracy) - - # 非上传轮,清空特征 - if ep != epochs: - self.feat_lst = [] - self.label_lst = [] - - if self.strategy == "distill": - self.logit_tracker.clear() - - return np.mean(epoch_loss), np.mean(epoch_accuracy) + def __init__(self, base_trainer): + super().__init__( + base_trainer.model, + base_trainer.optimizer, + base_trainer.criterion, + base_trainer.device, + base_trainer.display, + ) + self.feat_lst, self.label_lst = [], [] + self.base_trainer = base_trainer + self.eval_model = self.base_trainer.eval_model - def update_feat(self, h, target): + def update_info(self): # 保存中间特征 + h, target = self.h, self.target self.feat_lst.append(h) self.label_lst.append(target) + self.base_trainer.update_info() - def batch(self, data, target): - h, pro1, output = self.model(data) - loss = self.criterion(output, target) - - if self.model.training: - if self.strategy == "lsd": - loss += self.lsd_loss(data, output) - elif self.strategy == "distill": - loss += self.distill_loss(output, target) - elif self.strategy == "moon": - loss += self.moon_loss(data, pro1) - else: - loss += self.fed_loss_d[self.strategy] - - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() + def clear_info(self): + self.feat_lst, self.label_lst = [], [] + self.base_trainer.clear_info() - self.update_feat(h, target) - if self.strategy == "distill": - self.logit_tracker.update(output, target) - elif self.strategy == "dyn": - self.update_prev_grads() - - iter_loss = loss.data.item() - iter_acc = self.metrics(output, target) - return iter_loss, iter_acc + def forward(self, data, target): + h, pro1, output = self.model(data) + # 更新所有可能需要的数据 + self.h, self.data, self.pro1 = h, data, pro1 + self.output, self.target = output, target + return output diff --git a/example/MOON_reproduction/main.py b/example/MOON_reproduction/main.py index 25d0a34..5ff36ed 100644 --- a/example/MOON_reproduction/main.py +++ b/example/MOON_reproduction/main.py @@ -2,29 +2,21 @@ import argparse import copy import os +from collections import defaultdict import torch import torch.nn as nn import torch.optim as optim +import MyTrainers from flearn.client import Client, datasets from flearn.client.utils import get_free_gpu_id -from flearn.common.strategy import LG -from flearn.common.utils import setup_seed +from flearn.common.utils import init_strategy, setup_seed 
from flearn.server import Communicator as sc from flearn.server import Server from model import GlobModel, ModelFedCon from MyClients import DistillClient, DynClient, LSDClient, MOONClient, ProxClient -from MyStrategys import CCVR, DF, Distill, Dyn -from MyTrainers import ( - AVGTrainer, - CCVRTrainer, - DistillTrainer, - DynTrainer, - LSDTrainer, - MOONTrainer, - ProxTrainer, -) +from MyStrategys import CCVR, DF, DFCCVR, Distill, Dyn from utils import get_dataloader, partition_data # 设置随机数种子 @@ -47,6 +39,7 @@ parser.add_argument("--suffix", dest="suffix", default="", type=str) parser.add_argument("--iid", dest="iid", action="store_true") parser.add_argument("--ccvr", dest="ccvr", action="store_true") +parser.add_argument("--df", dest="df", action="store_true") parser.add_argument( "--dataset_name", dest="dataset_name", @@ -62,34 +55,10 @@ args = parser.parse_args() iid = args.iid -base_strategy = args.strategy_name.lower() +strategy_name = args.strategy_name.lower() dataset_name = args.dataset_name dataset_fpath = args.dataset_fpath -# 可运行策略 -trainer_d = { - "avg": AVGTrainer, - "moon": MOONTrainer, - "prox": ProxTrainer, - "lsd": LSDTrainer, - "dyn": DynTrainer, - "distill": DistillTrainer, - "lg": AVGTrainer, - "df": AVGTrainer, -} - -client_d = { - "avg": Client, - "moon": MOONClient, - "prox": ProxClient, - "lsd": LSDClient, - "dyn": DynClient, - "distill": DistillClient, - "lg": Client, - "df": Client, -} - - model_fpath = "./client_checkpoint" if not os.path.isdir(model_fpath): os.mkdir(model_fpath) @@ -105,9 +74,7 @@ model_base = ModelFedCon("resnet50-cifar100", out_dim=256, n_classes=200) glob_model_base = GlobModel("resnet50-cifar100", out_dim=256, n_classes=200) -# 设置训练集以及策略 -trainer = trainer_d[base_strategy] if base_strategy in trainer_d.keys() else None - +# 设置策略 shared_key_layers = [ "l1.weight", "l1.bias", @@ -116,51 +83,65 @@ "l3.weight", "l3.bias", ] -strategy_d = { - "dyn": Dyn(model_fpath, copy.deepcopy(model_base).state_dict()), - "distill": Distill(model_fpath), - "df": DF(model_fpath, copy.deepcopy(model_base)), - "lg": LG(model_fpath, shared_key_layers=shared_key_layers), +custom_strategy_d = defaultdict(lambda: None) +custom_strategy_d.update( + { + "dyn": Dyn(model_fpath, h=model_base.state_dict()), + "distill": Distill(model_fpath), + } +) +strategy = init_strategy( + strategy_name, custom_strategy_d[strategy_name], model_fpath, shared_key_layers +) +kwargs = { + "h": model_base.state_dict(), + "shared_key_layers": shared_key_layers, +} +if args.ccvr and args.df: + strategy = DFCCVR(model_fpath, model_base, glob_model_base, strategy, **kwargs) +elif args.ccvr: + strategy = CCVR(model_fpath, glob_model_base, strategy, **kwargs) +elif args.df: + strategy = DF(model_fpath, model_base, strategy, **kwargs) + +# 设置 训练器-客户端 +conf_d = { + "avg": {"trainer": MyTrainers.AVGTrainer, "client": Client}, + "moon": {"trainer": MyTrainers.MOONTrainer, "client": MOONClient}, + "prox": {"trainer": MyTrainers.ProxTrainer, "client": ProxClient}, + "lsd": {"trainer": MyTrainers.LSDTrainer, "client": LSDClient}, + "dyn": {"trainer": MyTrainers.DynTrainer, "client": DynClient}, + "lg": {"trainer": MyTrainers.AVGTrainer, "client": Client}, + "distill": {"trainer": MyTrainers.DistillTrainer, "client": DistillClient}, } - -if base_strategy == "dyn": - ccvr_strategy = CCVR( - model_fpath, glob_model_base, base_strategy, h=model_base.state_dict() - ) -elif base_strategy == "lg": - ccvr_strategy = CCVR( - model_fpath, glob_model_base, base_strategy, shared_key_layers=shared_key_layers - ) 
-else: - ccvr_strategy = CCVR(model_fpath, glob_model_base, base_strategy) def inin_single_client(model_base, client_id): model_ = copy.deepcopy(model_base) optim_ = optim.SGD(model_.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5) + criterion = nn.CrossEntropyLoss() trainloader, testloader, _, _ = get_dataloader( dataset_name, dataset_fpath, batch_size, batch_size, net_dataidx_map[client_id] ) + trainer = conf_d[strategy_name]["trainer"] + c_trainer = trainer(model_, optim_, criterion, device, False) + if args.ccvr: - c_trainer = CCVRTrainer( - model_, optim_, nn.CrossEntropyLoss(), device, False, strategy=base_strategy - ) - else: - c_trainer = trainer(model_, optim_, nn.CrossEntropyLoss(), device, False) + c_trainer = MyTrainers.CCVRTrainer(c_trainer) return { "trainer": c_trainer, "trainloader": trainloader, - # "testloader": testloader, - "testloader": test_dl, + "testloader": testloader, # 对应数据集的所有测试数据,未切割 "model_fname": "client{}_round_{}.pth".format(client_id, "{}"), "client_id": client_id, "model_fpath": model_fpath, "epoch": args.local_epoch, "dataset_name": dataset_name, - "strategy_name": base_strategy, + "strategy_name": strategy_name, + "strategy": copy.deepcopy(strategy), "save": False, "log": False, } @@ -204,22 +185,13 @@ def inin_single_client(model_base, client_id): client_lst = [] for client_id in range(client_numbers): c_conf = inin_single_client(model_base, client_id) - if args.ccvr: - c_conf["strategy"] = copy.deepcopy(ccvr_strategy) - elif base_strategy in strategy_d.keys(): - c_conf["strategy"] = copy.deepcopy(strategy_d[base_strategy]) - else: - c_conf["strategy_name"] = "avg" - client_lst.append(client_d[base_strategy](c_conf)) - - s_conf = {"model_fpath": model_fpath, "strategy_name": base_strategy} - if args.ccvr: - s_conf["strategy"] = copy.deepcopy(ccvr_strategy) - elif base_strategy in strategy_d.keys(): - s_conf["strategy"] = copy.deepcopy(strategy_d[base_strategy]) - else: - s_conf["strategy_name"] = "avg" + client_lst.append(conf_d[strategy_name]["client"](c_conf)) + s_conf = { + "model_fpath": model_fpath, + "strategy": strategy, + "strategy_name": copy.deepcopy(strategy), + } sc_conf = { "server": Server(s_conf), "Round": 100, @@ -235,19 +207,22 @@ def inin_single_client(model_base, client_id): kwargs = {"device": device} # 随意选取,可替换为更合适的数据集;或者生成随机数,见_create_data_randomly - trainset, testset = datasets.get_datasets("cifar100", dataset_fpath) - _, glob_testloader = datasets.get_dataloader(trainset, testset, 100, num_workers=0) - kwargs = { - "lr": 1e-2, - "T": 2, - "epoch": 1, - "method": "avg_logits", - "kd_loader": glob_testloader, - "device": device, - } + if args.df: + trainset, testset = datasets.get_datasets("cifar100", dataset_fpath) + _, glob_testloader = datasets.get_dataloader( + trainset, testset, 100, num_workers=0 + ) + kwargs = { + "lr": 1e-2, + "T": 2, + "epoch": 1, + "method": "avg_logits", + "kd_loader": glob_testloader, + "device": device, + } for ri in range(sc_conf["Round"]): - if args.ccvr: + if args.ccvr or args.df: loss, train_acc, test_acc = server_o.run(ri, k=k, **kwargs) else: loss, train_acc, test_acc = server_o.run(ri, k=k)
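
Note on the trainer refactor: the per-strategy batch() overrides are removed above and replaced
with forward() / fed_loss() / eval_model() / update_info() / clear_info() hooks, so the actual
training step must now live in the base flearn Trainer, which is not part of this diff. Below is
a minimal sketch of the contract the new hooks appear to assume; the class name, the accuracy
metric, and the base-class internals are illustrative assumptions, not code from this patch.

    class HookedTrainerSketch:
        """Sketch of how a base Trainer could drive the hooks used in this patch."""

        def __init__(self, model, optimizer, criterion, device):
            self.model, self.optimizer = model, optimizer
            self.criterion, self.device = criterion, device

        # Subclasses (AVGTrainer, MOONTrainer, DynTrainer, ...) override these hooks.
        def forward(self, data, target):
            _, _, output = self.model(data)  # ModelFedCon returns (h, projection, logits)
            return output

        def fed_loss(self):
            return 0                         # e.g. MOON contrastive / FedProx / FedDyn penalty

        def update_info(self):
            pass                             # e.g. cache logits, features or gradients for upload

        def batch(self, data, target):
            output = self.forward(data, target)
            loss = self.criterion(output, target)
            if self.model.training:
                loss = loss + self.fed_loss()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.update_info()
            acc = (output.argmax(dim=-1) == target).float().mean().item()
            return loss.item(), acc

On the CLI side, --ccvr and --df can now be combined: passing both wraps the selected base
strategy in DFCCVR (DF first, then CCVR), while either flag alone selects CCVR or DF.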