Skip to content

Commit

Permalink
add func server_post_processing to Strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 3, 2022
1 parent 15bbfb9 commit 4bfc579
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 48 deletions.
2 changes: 1 addition & 1 deletion flearn/common/strategy/avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def server(self, ensemble_params_lst, round_):
try:
w_glob = self.server_ensemble(agg_weight_lst, w_local_lst)
except Exception as e:
return self.server_exception(e)
self.server_exception(e)

return {"w_glob": w_glob}

Expand Down
2 changes: 1 addition & 1 deletion flearn/common/strategy/bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def server(self, ensemble_params_lst, round_):

w_glob = self.server_ensemble(agg_weight_lst, w_local_lst, key_lst=key_lst)
except Exception as e:
return self.server_exception(e)
self.server_exception(e)
return {"w_glob": w_glob}
2 changes: 1 addition & 1 deletion flearn/common/strategy/lg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def server(self, ensemble_params_lst, round_):
agg_weight_lst, w_local_lst, key_lst=self.shared_key_layers
)
except Exception as e:
return self.server_exception(e)
self.server_exception(e)
return {"w_glob": w_glob}
2 changes: 1 addition & 1 deletion flearn/common/strategy/lg_reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ def server(self, ensemble_params_lst, round_):
try:
w_glob = self.server_ensemble(agg_weight_lst, w_local_lst)
except Exception as e:
return self.server_exception(e)
self.server_exception(e)
return {"w_glob": w_glob}
16 changes: 12 additions & 4 deletions flearn/common/strategy/pav.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,21 @@ def pav_kd(self, w_local_lst, w_glob, **kwargs):

return w_glob

def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
w_local_lst = self.extract_lst(ensemble_params_lst, "params")
ensemble_params["w_glob"] = self.pav_kd(
w_local_lst, ensemble_params["w_glob"], **kwargs
)

return ensemble_params

def server(self, ensemble_params_lst, round_, **kwargs):
"""服务端聚合客户端模型并蒸馏
Args:
kwargs : dict
蒸馏所需参数
"""
g_shared = super().server(ensemble_params_lst, round_)
w_local_lst = self.extract_lst(ensemble_params_lst, "params")
g_shared["w_glob"] = self.pav_kd(w_local_lst, g_shared["w_glob"], **kwargs)
return g_shared
ensemble_params = super().server(ensemble_params_lst, round_)
return self.server_post_processing(
ensemble_params_lst, ensemble_params, **kwargs
)
64 changes: 24 additions & 40 deletions flearn/common/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def extract_lst(lst, key):
return list(map(lambda x: x[key], lst))

def server_pre_processing(self, ensemble_params_lst):
"""提取服务器端收到的的参数
"""服务器端对参数的预处理, 提取服务器端收到的的参数
Args:
ensemble_params_lst : list
Expand All @@ -47,6 +47,23 @@ def server_pre_processing(self, ensemble_params_lst):
w_local_lst.append(p["params"])
return agg_weight_lst, w_local_lst

def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
"""服务器端对参数的后处理, 对集成后的参数做进一步的更新
Args:
ensemble_params_lst : list
每个客户端发送的参数组成的列表
ensemble_params : dict
集成后的参数,准备发回客户端的参数
Returns:
dict :
同ensemble_params结构
"""
return ensemble_params

def revice_processing(self, data):
"""数据加密并转为二进制流
Expand Down Expand Up @@ -97,22 +114,9 @@ def server_exception(self, e):
Args:
e : Exception
异常消息
Returns:
dict : Dict {
'glob_params' : str
编码后的全局模型
'code' : int
状态码,
'msg' : str
状态消息,
}
"""
print(e)
print("检查客户端模型参数是否正常")
return ""
raise SystemExit("检查客户端模型参数是否正常")

def server_ensemble(self, agg_weight_lst, w_local_lst, key_lst=None):
"""服务端集成函数
Expand Down Expand Up @@ -149,8 +153,8 @@ def client(self, trainer, i, agg_weight=1.0):
"""获取客户端需要上传的模型参数及所占全局模型的权重
Args:
w_local : collections.OrderedDict
模型参数,model.state_dict()
trainer : Object
训练器
agg_weight : float
模型参数所占权重(该客户端聚合所占权重)
Expand All @@ -164,11 +168,6 @@ def client(self, trainer, i, agg_weight=1.0):
模型参数所占权重(该客户端聚合所占权重)
}
"""
w_shared = {"params": {}, "agg_weight": agg_weight}
# for k in self.shared_key_layers:
# w_shared['params'][k] = w_local[k].cpu()
# return w_shared

return NotImplemented

@abstractmethod
Expand All @@ -184,23 +183,11 @@ def server(self, ensemble_params_lst, round_):
Returns:
dict : Dict {
'glob_params' : str
编码后的全局模型
'code' : int
状态码,
'msg' : str
状态消息,
'w_glob' : collections.OrderedDict
集成后的模型参数字典
}
"""
agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst)
# N, idxs_users, w_glob = self.server_pre_processing(w_local_lst)
try:
return NotImplemented
except Exception as e:
return self.server_exception(e)
return {"w_glob": w_glob}
return NotImplemented

@abstractmethod
def client_revice(self, trainer, w_glob_b):
Expand All @@ -217,7 +204,4 @@ def client_revice(self, trainer, w_glob_b):
collections.OrderedDict
更新后的模型参数,model.state_dict()
"""
w_glob = pickle.loads(w_glob_b)
# for k in self.shared_key_layers:
# w_local[k] = w_glob[k]
return NotImplemented

0 comments on commit 4bfc579

Please sign in to comment.