diff --git a/flearn/common/strategy/avg.py b/flearn/common/strategy/avg.py
index 36c1b22..dc2907d 100644
--- a/flearn/common/strategy/avg.py
+++ b/flearn/common/strategy/avg.py
@@ -24,7 +24,7 @@ def server(self, ensemble_params_lst, round_):
         try:
             w_glob = self.server_ensemble(agg_weight_lst, w_local_lst)
         except Exception as e:
-            return self.server_exception(e)
+            self.server_exception(e)

         return {"w_glob": w_glob}
diff --git a/flearn/common/strategy/bn.py b/flearn/common/strategy/bn.py
index 9770d15..0ec7953 100644
--- a/flearn/common/strategy/bn.py
+++ b/flearn/common/strategy/bn.py
@@ -28,5 +28,5 @@ def server(self, ensemble_params_lst, round_):
             w_glob = self.server_ensemble(agg_weight_lst, w_local_lst, key_lst=key_lst)
         except Exception as e:
-            return self.server_exception(e)
+            self.server_exception(e)

         return {"w_glob": w_glob}
diff --git a/flearn/common/strategy/lg.py b/flearn/common/strategy/lg.py
index 151ca48..4b57869 100644
--- a/flearn/common/strategy/lg.py
+++ b/flearn/common/strategy/lg.py
@@ -35,5 +35,5 @@ def server(self, ensemble_params_lst, round_):
                 agg_weight_lst, w_local_lst, key_lst=self.shared_key_layers
             )
         except Exception as e:
-            return self.server_exception(e)
+            self.server_exception(e)
         return {"w_glob": w_glob}
diff --git a/flearn/common/strategy/lg_reverse.py b/flearn/common/strategy/lg_reverse.py
index 7d81465..fa4f2fa 100644
--- a/flearn/common/strategy/lg_reverse.py
+++ b/flearn/common/strategy/lg_reverse.py
@@ -38,5 +38,5 @@ def server(self, ensemble_params_lst, round_):
         try:
             w_glob = self.server_ensemble(agg_weight_lst, w_local_lst)
         except Exception as e:
-            return self.server_exception(e)
+            self.server_exception(e)
         return {"w_glob": w_glob}
diff --git a/flearn/common/strategy/pav.py b/flearn/common/strategy/pav.py
index aa2b609..009a8a4 100644
--- a/flearn/common/strategy/pav.py
+++ b/flearn/common/strategy/pav.py
@@ -158,13 +158,21 @@ def pav_kd(self, w_local_lst, w_glob, **kwargs):

         return w_glob

+    def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
+        w_local_lst = self.extract_lst(ensemble_params_lst, "params")
+        ensemble_params["w_glob"] = self.pav_kd(
+            w_local_lst, ensemble_params["w_glob"], **kwargs
+        )
+
+        return ensemble_params
+
     def server(self, ensemble_params_lst, round_, **kwargs):
         """Aggregate the client models on the server and distill

         Args:
             kwargs : dict
                 Parameters required for distillation
         """
-        g_shared = super().server(ensemble_params_lst, round_)
-        w_local_lst = self.extract_lst(ensemble_params_lst, "params")
-        g_shared["w_glob"] = self.pav_kd(w_local_lst, g_shared["w_glob"], **kwargs)
-        return g_shared
+        ensemble_params = super().server(ensemble_params_lst, round_)
+        return self.server_post_processing(
+            ensemble_params_lst, ensemble_params, **kwargs
+        )
diff --git a/flearn/common/strategy/strategy.py b/flearn/common/strategy/strategy.py
index 0fc97f1..6176718 100644
--- a/flearn/common/strategy/strategy.py
+++ b/flearn/common/strategy/strategy.py
@@ -28,7 +28,7 @@ def extract_lst(lst, key):
         return list(map(lambda x: x[key], lst))

     def server_pre_processing(self, ensemble_params_lst):
-        """Extract the parameters received by the server
+        """Server-side pre-processing: extract the parameters received by the server

         Args:
             ensemble_params_lst : list
@@ -47,6 +47,23 @@ def server_pre_processing(self, ensemble_params_lst):
             w_local_lst.append(p["params"])
         return agg_weight_lst, w_local_lst

+    def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
+        """Server-side post-processing: further update the ensembled parameters
+
+        Args:
+            ensemble_params_lst : list
+                List of the parameters sent by each client
+
+            ensemble_params : dict
+                The ensembled parameters, ready to be sent back to the clients
+
+        Returns:
+            dict :
+                Same structure as ensemble_params
+        """
+        return ensemble_params
+
     def revice_processing(self, data):
         """Encrypt the data and convert it to a binary stream
@@ -97,22 +114,9 @@ def server_exception(self, e):
         Args:
             e : Exception
                 Exception message
-
-        Returns:
-            dict : Dict {
-                'glob_params' : str
-                    Encoded global model
-
-                'code' : int
-                    Status code,
-
-                'msg' : str
-                    Status message,
-            }
         """
         print(e)
-        print("Check whether the client model parameters are valid")
-        return ""
+        raise SystemExit("Check whether the client model parameters are valid")

     def server_ensemble(self, agg_weight_lst, w_local_lst, key_lst=None):
         """Server-side ensemble function
@@ -149,8 +153,8 @@ def client(self, trainer, i, agg_weight=1.0):
         """Get the model parameters the client needs to upload and their weight in the global model

         Args:
-            w_local : collections.OrderedDict
-                Model parameters, model.state_dict()
+            trainer : Object
+                The trainer
             agg_weight : float
                 Weight of the model parameters (this client's weight in the aggregation)
@@ -164,11 +168,6 @@ def client(self, trainer, i, agg_weight=1.0):
                 Weight of the model parameters (this client's weight in the aggregation)
         }
         """
-        w_shared = {"params": {}, "agg_weight": agg_weight}
-        # for k in self.shared_key_layers:
-        #     w_shared['params'][k] = w_local[k].cpu()
-        # return w_shared
-
         return NotImplemented

     @abstractmethod
@@ -184,23 +183,11 @@ def server(self, ensemble_params_lst, round_):

         Returns:
             dict : Dict {
-                'glob_params' : str
-                    Encoded global model
-
-                'code' : int
-                    Status code,
-
-                'msg' : str
-                    Status message,
+                'w_glob' : collections.OrderedDict
+                    Dictionary of the ensembled model parameters
             }
         """
-        agg_weight_lst, w_local_lst = self.server_pre_processing(ensemble_params_lst)
-        # N, idxs_users, w_glob = self.server_pre_processing(w_local_lst)
-        try:
-            return NotImplemented
-        except Exception as e:
-            return self.server_exception(e)
-        return {"w_glob": w_glob}
+        return NotImplemented

     @abstractmethod
     def client_revice(self, trainer, w_glob_b):
@@ -217,7 +204,4 @@ def client_revice(self, trainer, w_glob_b):
             collections.OrderedDict
                 Updated model parameters, model.state_dict()
         """
-        w_glob = pickle.loads(w_glob_b)
-        # for k in self.shared_key_layers:
-        #     w_local[k] = w_glob[k]
        return NotImplemented
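
Taken together, the hunks change two things: server_exception is now called only for its side effect (it raises SystemExit instead of returning an encoded error string), and subclasses customise aggregation through the new server_post_processing hook rather than re-implementing server. The following minimal sketch is not part of the patch; it assumes the aggregation strategy in avg.py is exported as AVG and that parameters are PyTorch tensors, and the subclass name and scaling step are purely illustrative.

from flearn.common.strategy import AVG


class ScaledAVG(AVG):
    def server_post_processing(self, ensemble_params_lst, ensemble_params, **kwargs):
        # ensemble_params["w_glob"] is the state_dict produced by server_ensemble;
        # rescaling each tensor stands in for any real post-processing step.
        w_glob = ensemble_params["w_glob"]
        for k in w_glob:
            w_glob[k] = w_glob[k] * 0.5
        ensemble_params["w_glob"] = w_glob
        return ensemble_params

    def server(self, ensemble_params_lst, round_, **kwargs):
        # Base aggregation; on failure server_exception now raises SystemExit
        # instead of returning an empty string.
        ensemble_params = super().server(ensemble_params_lst, round_)
        return self.server_post_processing(
            ensemble_params_lst, ensemble_params, **kwargs
        )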