
Commit

update MOON (update df and ccvr) and (add cli for ccvr and df)
wnma3mz committed Feb 3, 2022
1 parent 4bfc579 commit 2548035
Showing 4 changed files with 217 additions and 365 deletions.
4 changes: 0 additions & 4 deletions example/MOON_reproduction/MyClients.py
@@ -1,5 +1,4 @@
# coding: utf-8

import copy

from flearn.client import Client
@@ -91,10 +90,7 @@ def revice(self, i, glob_params):
self.scheduler.step()
# self.trainer.model.load_state_dict(self.w_local_bak)
self.trainer.model.load_state_dict(update_w)
self.trainer.server_model = copy.deepcopy(self.trainer.model)
self.trainer.server_model.load_state_dict(update_w)

self.trainer.server_model.eval()
self.trainer.server_state_dict = copy.deepcopy(update_w)
return {
"code": 200,
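The MyClients.py change above stops keeping a full deep copy of the global model on the client and instead stores only its weights in self.trainer.server_state_dict. A minimal sketch of how a trainer could rebuild the frozen global model from that state dict when computing MOON's model-contrastive loss (the helper below is hypothetical, not part of this commit):

import copy

def build_global_model(trainer):
    # Rebuild the frozen global reference model from the stored weights;
    # assumes the global model shares the local model's architecture.
    glob_model = copy.deepcopy(trainer.model)
    glob_model.load_state_dict(trainer.server_state_dict)
    glob_model.eval()
    return glob_model
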
179 changes: 88 additions & 91 deletions example/MOON_reproduction/MyStrategys.py
@@ -6,7 +6,6 @@
import torch.nn as nn
import torch.optim as optim

from flearn.common import Encrypt
from flearn.common.distiller import Distiller, KDLoss
from flearn.common.strategy import AVG

@@ -24,13 +23,14 @@ def client_revice(self, trainer, data_glob_d):
return w_local, logits_glob

def server(self, ensemble_params_lst, round_):
g_shared = super(Distill, self).server(ensemble_params_lst, round_)
ensemble_params = super(Distill, self).server(ensemble_params_lst, round_)

logits_lst = self.extract_lst(ensemble_params_lst, "logits")
g_shared["logits_glob"] = self.aggregate_logits(logits_lst)
return g_shared
ensemble_params["logits_glob"] = self.aggregate_logits(logits_lst)
return ensemble_params

def aggregate_logits(self, logits_lst):
@staticmethod
def aggregate_logits(logits_lst):
user_logits = 0
for item in logits_lst:
user_logits += item
@@ -56,12 +56,42 @@ def dyn_f(self, w_glob, w_local_lst):
for k in self.h.keys():
w_glob[k] = w_glob[k] - self.alpha * self.h[k]
self.theta = w_glob
return w_glob

def server_post_processing(self, ensemble_params_lst, ensemble_params):
w_local_lst = self.extract_lst(ensemble_params_lst, "params")
ensemble_params["w_glob"] = self.dyn_f(ensemble_params["w_glob"], w_local_lst)
return ensemble_params

def server(self, ensemble_params_lst, round_):
ensemble_params = super().server(ensemble_params_lst, round_)
return self.server_post_processing(ensemble_params_lst, ensemble_params)
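
Only the tail of dyn_f is visible in this hunk. In FedDyn-style aggregation the collapsed part would update the correction term h from the drift between the client models and the previous global weights theta; the sketch below is an assumption about that step, not the repository's code:

def update_h(h, theta, w_local_lst, alpha):
    # h[k] <- h[k] - alpha * mean_i(w_i[k] - theta[k]), averaged over clients
    for k in h.keys():
        drift = sum(w[k] - theta[k] for w in w_local_lst) / len(w_local_lst)
        h[k] = h[k] - alpha * drift
    return h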


class ParentStrategy(AVG):
def __init__(self, model_fpath, strategy=None, **kwargs):
self.strategy = strategy
if self.strategy != None:
key_lst = list(self.strategy.__dict__.keys())
for k, v in kwargs.items():
if k in key_lst:
self.__dict__[k] = v
super().__init__(model_fpath)

def client(self, trainer, agg_weight=1.0):
if self.strategy != None:
return self.strategy.client(trainer, agg_weight)
return super().client(trainer, agg_weight)

def server(self, ensemble_params_lst, round_):
g_shared = super(Dyn, self).server(ensemble_params_lst, round_)
_, w_local_lst = self.server_pre_processing(ensemble_params_lst)
self.dyn_f(g_shared["w_glob"], w_local_lst)
return g_shared
if self.strategy != None:
return self.strategy.server(ensemble_params_lst, round_)
return super().server(ensemble_params_lst, round_)

def client_revice(self, trainer, data_glob_d):
if self.strategy != None:
return self.strategy.client_revice(trainer, data_glob_d)
return super().client_revice(trainer, data_glob_d)
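
ParentStrategy is the composition hook introduced by this commit: when a strategy instance is passed in, client/server/client_revice calls are delegated to it (and its attributes can be seeded via **kwargs); otherwise the plain AVG behaviour is used. A hypothetical composition sketch with placeholder paths and models:

# Placeholder objects; the real values come from the CLI scripts added in this commit.
base = Distill("client_model.pth")
ccvr = CCVR("client_model.pth", glob_model_base=glob_model, strategy=base)

# ccvr.client()/ccvr.server() now run Distill's logit aggregation first and then
# apply CCVR's feature-based classifier calibration on top of the result.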


class DFDistiller(Distiller):
@@ -111,7 +141,8 @@ def multi(
soft_target_lst = []
for teacher in self.teacher_lst:
with torch.no_grad():
soft_target_lst.append(teacher(x))
_, _, soft_target = teacher(x)
soft_target_lst.append(soft_target)

loss = self.multi_loss(method, soft_target_lst, output)

@@ -122,18 +153,13 @@ def multi(
return student.state_dict()
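
multi_loss itself is not shown in this hunk. One plausible reading, given the weight_lst computed in DF below, is a weighted fusion of the teachers' soft targets followed by the KD loss; the helper below is purely illustrative and the KDLoss argument order is an assumption:

def weighted_soft_target_loss(kd_loss, soft_target_lst, weight_lst, output):
    # Fuse the per-teacher soft targets with their aggregation weights, then
    # distill the student output towards the fused target.
    soft_target = sum(w * t for w, t in zip(weight_lst, soft_target_lst))
    return kd_loss(output, soft_target)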


class DF(AVG):
"""
Ensemble distillation for robust model fusion in federated learning
[1] Lin T, Kong L, Stich S U, et al. Ensemble distillation for robust model fusion in federated learning[J]. arXiv preprint arXiv:2006.07242, 2020.
"""

def __init__(self, model_fpath, model_base):
super().__init__(model_fpath)
class DF(ParentStrategy):
def __init__(self, model_fpath, model_base, strategy=None, **kwargs):
super().__init__(model_fpath, strategy, **kwargs)
self.model_base = model_base

def ensemble_w(self, ensemble_params_lst, w_glob, **kwargs):
def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
w_glob = ensemble_params["w_glob"]
agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst)

teacher_lst = []
@@ -144,19 +170,21 @@ def ensemble_w(self, ensemble_params_lst, w_glob, **kwargs):
self.model_base.load_state_dict(w_glob)
student = copy.deepcopy(self.model_base)

self.distiller = DFDistiller(
kwargs.pop("kd_loader"),
kwargs.pop("device"),
kd_loss=KDLoss(kwargs.pop("T")),
kd_loader, device = kwargs.pop("kd_loader"), kwargs.pop("device")
temperature = kwargs.pop("T")
distiller = DFDistiller(
kd_loader,
device,
kd_loss=KDLoss(temperature),
)

molecular = np.sum(agg_weight_lst)
weight_lst = [w / molecular for w in agg_weight_lst]
# agg_weight_lst: should be assigned according to each model's performance on the validation set
w_glob = self.distiller.multi(
ensemble_params["w_glob"] = distiller.multi(
teacher_lst, student, kwargs.pop("method"), weight_lst=weight_lst, **kwargs
)
return w_glob
return ensemble_params

def server(self, ensemble_params_lst, round_, **kwargs):
"""
@@ -169,11 +197,10 @@ def server(self, ensemble_params_lst, round_, **kwargs):
"kd_loader": 蒸馏数据集,仅需输入,无需标签
}
"""
g_shared = super(DF, self).server(ensemble_params_lst, round_)
g_shared["w_glob"] = self.ensemble_w(
ensemble_params_lst, g_shared["w_glob"], **kwargs
ensemble_params = super().server(ensemble_params_lst, round_)
return self.server_post_processing(
ensemble_params_lst, ensemble_params, **kwargs
)
return g_shared
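
A hypothetical call matching the kwargs documented in the docstring above; the path, loader, device and method name are placeholders for whatever the new CLI passes in:

df = DF("client_model.pth", model_base=base_model, strategy=None)
ensemble_params = df.server(
    ensemble_params_lst,
    round_,
    T=2,                  # distillation temperature
    kd_loader=kd_loader,  # unlabeled distillation loader
    device="cuda:0",
    method="avg_logits",  # the method name is an assumption
)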


class ReTrain:
@@ -209,25 +236,10 @@ def run(self, student, lr=0.01):
return student.state_dict()


class CCVR(Distill, Dyn, DF):
def __init__(
self,
model_fpath,
glob_model_base,
model_base=None,
strategy=None,
h=None,
shared_key_layers=None,
):
# kwargs = {"model_fpath": model_fpath, "model_base": model_base, "h": h}
# super().__init__(**kwargs)
self.model_fpath = model_fpath
self.model_base = model_base
self.h = h
self.glob_model = glob_model_base
self.strategy = strategy
self.shared_key_layers = shared_key_layers
self.encrypt = Encrypt()
class CCVR(ParentStrategy):
def __init__(self, model_fpath, glob_model_base, strategy=None, **kwargs):
super().__init__(model_fpath, strategy, **kwargs)
self.glob_model_base = glob_model_base

@staticmethod
def client_mean_feat(feat_lst, label_lst):
@@ -256,16 +268,7 @@ def client_mean_feat(feat_lst, label_lst):
return upload_d

def client(self, trainer, agg_weight=1.0):
if self.strategy == "lg":
w_shared = {"agg_weight": agg_weight}
w_local = trainer.weight
w_shared["params"] = {k: w_local[k].cpu() for k in self.shared_key_layers}
else:
w_shared = super(CCVR, self).client(trainer, agg_weight)

if self.strategy == "distill":
w_shared["logits"] = trainer.logit_tracker.avg()

w_shared = super().client(trainer, agg_weight)
w_shared["fd"] = self.client_mean_feat(trainer.feat_lst, trainer.label_lst)
return w_shared
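
client_mean_feat is collapsed in this diff; conceptually it groups the features collected by the trainer by class label and averages them before upload. A minimal sketch of that per-class averaging, assuming feat_lst and label_lst are lists of per-batch tensors:

import torch

def mean_feat_by_class(feat_lst, label_lst):
    feats = torch.cat(feat_lst)    # (N, feature_dim)
    labels = torch.cat(label_lst)  # (N,)
    return {int(c): feats[labels == c].mean(dim=0) for c in labels.unique()}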

@@ -312,47 +315,41 @@ def server_mean_feat(self, fd_lst):

return fd_d

def server(self, ensemble_params_lst, round_, **kwargs):
agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst)
try:
w_glob = self.server_ensemble(agg_weight_lst, w_local_lst)
except Exception as e:
return self.server_exception(e)
g_shared = {"w_glob": w_glob}

if self.strategy == "dyn":
_, w_local_lst = self.server_pre_processing(ensemble_params_lst)
self.dyn_f(g_shared["w_glob"], w_local_lst)
elif self.strategy == "distill":
logits_lst = self.extract_lst(ensemble_params_lst, "logits")
g_shared["logits_glob"] = self.aggregate_logits(logits_lst)
elif self.strategy == "df":
g_shared["w_glob"] = self.ensemble_w(
ensemble_params_lst, g_shared["w_glob"], **kwargs
)

def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
# extract the uploaded feature statistics
fd_lst = self.extract_lst(ensemble_params_lst, "fd")
fd_d = self.server_mean_feat(fd_lst)

# retrain the classifier
self.retrainer = ReTrain(fd_d, kwargs["device"])
self.glob_model = self.load_model(self.glob_model, g_shared["w_glob"])
w_train = self.retrainer.run(self.glob_model)
self.glob_model_base = self.load_model(
self.glob_model_base, ensemble_params["w_glob"]
)
w_train = self.retrainer.run(self.glob_model_base)

for k in w_train.keys():
g_shared["w_glob"][k] = w_train[k].cpu()
ensemble_params["w_glob"][k] = w_train[k].cpu()

return g_shared
return ensemble_params

def client_revice(self, trainer, data_glob_d):
w_local = trainer.weight
w_glob = data_glob_d["w_glob"]
for k in w_glob.keys():
w_local[k] = w_glob[k]
def server(self, ensemble_params_lst, round_, **kwargs):
ensemble_params = super().server(ensemble_params_lst, round_)
return self.server_post_processing(
ensemble_params_lst, ensemble_params, **kwargs
)

if self.strategy == "distill":
logits_glob = data_glob_d["logits_glob"]
return w_local, logits_glob

return w_local
class DFCCVR(CCVR):
def __init__(
self, model_fpath, model_base, glob_model_base, strategy=None, **kwargs
):
super().__init__(model_fpath, glob_model_base, strategy, **kwargs)
self.model_base = model_base
self.df = DF(model_fpath, model_base, strategy, **kwargs)

def server(self, ensemble_params_lst, round_, **kwargs):
# run DF first, then CCVR
ensemble_params = self.df.server(ensemble_params_lst, round_, **kwargs)
return self.server_post_processing(
ensemble_params_lst, ensemble_params, **kwargs
)
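
DFCCVR chains the two post-processing steps: the wrapped DF instance first distills the aggregated weights against the client models, then CCVR's server_post_processing recalibrates the classifier on the uploaded feature statistics. A hypothetical end-to-end call with placeholder models, loader and device:

dfccvr = DFCCVR(
    "client_model.pth",
    model_base=base_model,       # teacher/student architecture used by DF
    glob_model_base=glob_model,  # model whose classifier CCVR retrains
    strategy=None,
)
ensemble_params = dfccvr.server(
    ensemble_params_lst,
    round_,
    T=2,
    kd_loader=kd_loader,
    device="cuda:0",
    method="avg_logits",         # the method name is an assumption
)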
