From 33d804dd6dc8444c7ac401412e91da7b8e778061 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Thu, 10 Oct 2024 22:55:26 +0800 Subject: [PATCH] update --- baselines/fedbabu/fedbabu/client.py | 8 +-- baselines/fedbabu/fedbabu/task.py | 108 +++++++++++++++++++++++----- 2 files changed, 94 insertions(+), 22 deletions(-) diff --git a/baselines/fedbabu/fedbabu/client.py b/baselines/fedbabu/fedbabu/client.py index 416173696c5..741c1807f9b 100644 --- a/baselines/fedbabu/fedbabu/client.py +++ b/baselines/fedbabu/fedbabu/client.py @@ -3,7 +3,7 @@ from flwr.client import NumPyClient, ClientApp from fedbabu.task import ( - Net, + MobileNetCifar, DEVICE, load_data, get_weights, @@ -33,7 +33,7 @@ def evaluate(self, parameters, config): def client_fn(cid): # Load model and data - net = Net().to(DEVICE) + net = MobileNetCifar().to(DEVICE) trainloader, valloader = load_data(int(cid), 2) # Return Client instance @@ -41,6 +41,4 @@ def client_fn(cid): # Flower ClientApp -app = ClientApp( - client_fn, -) +app = ClientApp(client_fn) diff --git a/baselines/fedbabu/fedbabu/task.py b/baselines/fedbabu/fedbabu/task.py index 981311aa6ba..82b0bba7da0 100644 --- a/baselines/fedbabu/fedbabu/task.py +++ b/baselines/fedbabu/fedbabu/task.py @@ -9,34 +9,108 @@ from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import DirichletPartitioner +from flwr_datasets.preprocessor import Merger DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -class Net(nn.Module): - """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" +''' +Modified from https://github.com/jhoon-oh/FedBABU/blob/master/models/Nets.py +MobileNet in PyTorch. +See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" +for more details. +''' + + +class Block(nn.Module): + '''Depthwise conv + Pointwise conv''' + + def __init__(self, in_planes, out_planes, stride=1): + super(Block, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + in_planes, + kernel_size=3, + stride=stride, + padding=1, + groups=in_planes, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=False) + self.conv2 = nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=False) - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + return out + + +class MobileNetCifar(nn.Module): + # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 + cfg = [ + 64, + (128, 2), + 128, + (256, 2), + 256, + (512, 2), + 512, + 512, + 512, + 512, + 512, + (1024, 2), + 1024, + ] + + def __init__(self, num_classes=10): + super(MobileNetCifar, self).__init__() + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(32, track_running_stats=False) + self.layers = self._make_layers(in_planes=32) + self.linear = nn.Linear(1024, num_classes) + + def _make_layers(self, in_planes): + layers = [] + for x in self.cfg: + out_planes = x if isinstance(x, int) else x[0] + stride = 1 if isinstance(x, int) else x[1] + layers.append(Block(in_planes, out_planes, stride)) + in_planes = out_planes + return nn.Sequential(*layers) def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return self.fc3(x) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layers(out) + out = F.avg_pool2d(out, 2) + out = out.view(out.size(0), -1) + logits = self.linear(out) + + return logits + + def extract_features(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layers(out) + out = F.avg_pool2d(out, 2) + out = out.view(out.size(0), -1) + + return out def load_data(partition_id, num_partitions): """Load partition CIFAR10 data.""" - fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) + partitioner = DirichletPartitioner( + num_partitions=num_partitions, partition_by="label", alpha=0.1 + ) + fds = FederatedDataset( + dataset="cifar10", + partitioners={"train": partitioner}, + preprocessor=Merger({"train": ("train", "test")}), + ) partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42)