forked from shaoxiongji/federated-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_fed.py
114 lines (100 loc) · 4.18 KB
/
main_fed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
if __name__ == '__main__':
    # Script entry point for federated training:
    #   1. parse CLI options and pick the compute device,
    #   2. load the dataset and partition sample indices across users,
    #   3. build the global model,
    #   4. run args.epochs communication rounds (local training on a sampled
    #      subset of users, then aggregation of their weights via FedAvg),
    #   5. save the training-loss curve and report train/test accuracy.
    import os  # local import: only needed to create the plot output directory

    # parse args
    args = args_parser()
    # Use the requested GPU only when CUDA is available and a GPU index was
    # given; args.gpu == -1 selects the CPU.
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        # Reject the unsupported configuration before triggering a download.
        if not args.iid:
            raise SystemExit('Error: only consider IID setting in CIFAR10')
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        dict_users = cifar_iid(dataset_train, args.num_users)
    else:
        raise SystemExit('Error: unrecognized dataset')
    # Shape of a single (transformed) training image; used to size the MLP input.
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        # Flattened input size = product of all image dimensions.
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        raise SystemExit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []  # average local loss per communication round

    if args.all_clients:
        print("Aggregation over all clients")
        # One slot per user. Every slot initially references the same w_glob
        # dict (no copies); a slot is replaced once its user has been sampled.
        w_locals = [w_glob for _ in range(args.num_users)]
    for round_idx in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            # Aggregate only over the users sampled in this round.
            w_locals = []
        # Number of users sampled per round: at least one.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            # Each user trains on its own copy so net_glob is left untouched
            # until aggregation.
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    # Make sure the output directory exists so a finished training run is not
    # lost to a missing-directory error when saving the figure.
    os.makedirs('./save', exist_ok=True)
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
    plt.close()

    # testing
    net_glob.eval()
    # Kept under a separate name so the per-round loss history in loss_train
    # is not overwritten by the evaluation loss.
    acc_train, loss_train_eval = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))