Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Oct 10, 2024
1 parent 91fe97d commit 33d804d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 22 deletions.
8 changes: 3 additions & 5 deletions baselines/fedbabu/fedbabu/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from flwr.client import NumPyClient, ClientApp

from fedbabu.task import (
Net,
MobileNetCifar,
DEVICE,
load_data,
get_weights,
Expand Down Expand Up @@ -33,14 +33,12 @@ def evaluate(self, parameters, config):

def client_fn(cid):
# Load model and data
net = Net().to(DEVICE)
net = MobileNetCifar().to(DEVICE)
trainloader, valloader = load_data(int(cid), 2)

# Return Client instance
return FlowerClient(net, trainloader, valloader).to_client()


# Flower ClientApp
app = ClientApp(
client_fn,
)
app = ClientApp(client_fn)
108 changes: 91 additions & 17 deletions baselines/fedbabu/fedbabu/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,108 @@
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from flwr_datasets.preprocessor import Merger

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
'''
Modified from https://github.com/jhoon-oh/FedBABU/blob/master/models/Nets.py
MobileNet in PyTorch.
See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
for more details.
'''


class Block(nn.Module):
'''Depthwise conv + Pointwise conv'''

def __init__(self, in_planes, out_planes, stride=1):
super(Block, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
in_planes,
kernel_size=3,
stride=stride,
padding=1,
groups=in_planes,
bias=False,
)
self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=False)
self.conv2 = nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False
)
self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=False)

def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
return out


class MobileNetCifar(nn.Module):
# (128,2) means conv planes=128, conv stride=2, by default conv stride=1
cfg = [
64,
(128, 2),
128,
(256, 2),
256,
(512, 2),
512,
512,
512,
512,
512,
(1024, 2),
1024,
]

def __init__(self, num_classes=10):
super(MobileNetCifar, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(32, track_running_stats=False)
self.layers = self._make_layers(in_planes=32)
self.linear = nn.Linear(1024, num_classes)

def _make_layers(self, in_planes):
layers = []
for x in self.cfg:
out_planes = x if isinstance(x, int) else x[0]
stride = 1 if isinstance(x, int) else x[1]
layers.append(Block(in_planes, out_planes, stride))
in_planes = out_planes
return nn.Sequential(*layers)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layers(out)
out = F.avg_pool2d(out, 2)
out = out.view(out.size(0), -1)
logits = self.linear(out)

return logits

def extract_features(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layers(out)
out = F.avg_pool2d(out, 2)
out = out.view(out.size(0), -1)

return out


def load_data(partition_id, num_partitions):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
partitioner = DirichletPartitioner(
num_partitions=num_partitions, partition_by="label", alpha=0.1
)
fds = FederatedDataset(
dataset="cifar10",
partitioners={"train": partitioner},
preprocessor=Merger({"train": ("train", "test")}),
)
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
Expand Down

0 comments on commit 33d804d

Please sign in to comment.