Skip to content

wnma3mz/flearn

Repository files navigation

联邦学习框架

Pypi

Quickstart

  1. 下载最新的release版本 并使用pip安装。或手动下载源码,在当前目录进行编译 python setup.py sdist bdist_wheel
  2. 切换至运行目录 cd example/mnist_cifar/
  3. 运行 python main.py --strategy_name avg --dataset_name mnist dataset_fpath 数据集路径

详细解释见 example/mnist_cifar中的README.md

进阶1——复现LG-FedAVG

  • 修改 Client.py,以及如何配置共享层

README.md

进阶2——复现FedProx

  • 修改训练器,以运用至更多任务与模型

README.md

进阶3——复现FedPAV

  • 修改客户端以及服务器端,以适用于FedPAV策略

README.md

支持策略

split-learning可见README.md,尚存在loss爆炸问题。

TODO

框架图

CFL

  • 对于active_client当前是随机选择客户端,可以使用更加先进的客户端选择算法
  • 对于Server端的evaluate,每种算法的目标大致可以分为两种,1) 获得一个模型,这个模型在某个数据集上能够取得很好的性能;2) 每个客户端获得不同的模型,每个客户端的模型在对应的数据集上能够取得很好的性能。因此这里分两部分处理
    • 仅在服务器端对单个模型在测试集上进行测试,直接获得测试结果(为区分,在log中会对这部分测试结果加上[Server])
    • 分别在所有/指定客户端中对客户端上的测试集上进行测试,在服务器端对结果取平均
    • 两种测试均进行
  • 对于辅助类
    • Distiller是为了进行知识蒸馏的操作,注:这里主要是对蒸馏方式进行整理,损失函数采用默认的KL散度
    • Encrypt是为了对通信参数进行加密,注:这里仅作base编解码操作
    • Logger是为了输出日志文件的数据
  • 对于Trainer,部分算法是修改loss函数,因此这里通过增加函数fed_loss,方便快速修改
  • 对于Strategy,部分算法需要修改上传以及下载参数(不仅上传模型参数,还有可能有其他信息),所以将Strategy处理为三个函数
    • client(仅对客户端有效):处理上传参数
    • server(仅对服务器端有效):服务器端处理所有客户端上传的参数
    • client_revice(仅对客户端有效):接收服务器端的参数并进行处理

工作流

  1. 服务器(Server)发送训练指令至各个客户端(Client)进行训练 (Server->Comm(S)->Comm(C)->Client);模拟实验时,(Server->Client)
  2. Client根据配置好的训练器(Trainer)进行训练,训练完成后返回指令至Server
  3. Server发送上传指令至Client,Client根据配置好的策略(Strategy),准备好上传的参数并进行上传,即Server发送指令后收到Client上传的参数
  4. Server根据预先配好的Strategy对参数进行聚合
  5. Server发送接收指令至Client,此时把参数发回至Client,Client根据配置好的Strategy进行接收
  6. 若Server继续发送测试指令至Client,Client还要对更新后的模型进行验证,并返回验证后的结果至Server。否则,Server直接进行验证

P.S.

  • Trainer中一般是需要配置联邦的损失函数 fed_loss,主要作用是为了防止灾难性遗忘
  • Distiller可以看作 fed_loss,也可以看作聚合策略的一种,所以可能会在 Strategy进行调用优化模型参数
  • Strategy其实可以分为Server和Client两个部分,其中Client有两个函数(上传和接收)。此处是将策略看作一个整体,即Client和Server都调用同一个Strategy