Skip to content

Commit

Permalink
Make communications pluggable in C++ SDK (#2763)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Apr 17, 2024
1 parent 3af7199 commit 58c47b5
Show file tree
Hide file tree
Showing 33 changed files with 12,110 additions and 5,872 deletions.
12 changes: 7 additions & 5 deletions .github/workflows/cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,18 @@ jobs:
cmake -DUSE_LOCAL_FLWR=ON -S . -B build
cmake --build build
pip install ../..
timeout 2m python server.py &
pid=$!
timeout 3m flower-superlink --insecure &
sleep 10
timeout 2m build/flwr_client 0 127.0.0.1:9092 &
sleep 3
build/flwr_client 0 127.0.0.1:8080 &
timeout 2m build/flwr_client 1 127.0.0.1:9092 &
sleep 3
build/flwr_client 1 127.0.0.1:8080 &
flower-server-app server:app --insecure &
pid=$!
wait $pid
res=$?
if [[ "$res" = "0" ]];
then echo "Training worked correctly";
then echo "Training worked correctly" && exit 0;
else echo "Training had an issue" && exit 1;
fi
Expand Down
1 change: 0 additions & 1 deletion examples/quickstart-cpp/fedavg_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def aggregate_evaluate(
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
print(results[0][1])
loss_aggregated = weighted_loss_avg(
[
(
Expand Down
20 changes: 13 additions & 7 deletions examples/quickstart-cpp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
import numpy as np
from fedavg_cpp import FedAvgCpp, weights_to_parameters

model_size = 2
initial_weights = [
np.array([1.0, 2.0], dtype=np.float64),
np.array([3.0], dtype=np.float64),
]
initial_parameters = weights_to_parameters(initial_weights)
strategy = FedAvgCpp(initial_parameters=initial_parameters)

app = fl.server.ServerApp(
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
model_size = 2
initial_weights = [
np.array([1.0, 2.0], dtype=np.float64),
np.array([3.0], dtype=np.float64),
]
initial_parameters = weights_to_parameters(initial_weights)
strategy = FedAvgCpp(initial_parameters=initial_parameters)
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
Expand Down
26 changes: 6 additions & 20 deletions examples/quickstart-cpp/src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
#include "start.h"

int main(int argc, char **argv) {
if (argc != 3 && argc != 4) {
std::cout << "Client takes three mandatory arguments and one optional as "
"follows: "
<< std::endl;
std::cout << "./client CLIENT_ID SERVER_URL [GRPC_MODE]" << std::endl;
std::cout
<< "GRPC_MODE is optional and can be either 'bidi' (default) or 'rere'."
<< std::endl;
std::cout << "Example: ./flwr_client 0 '127.0.0.1:8080' bidi" << std::endl;
std::cout << "This is the same as: ./flwr_client 0 '127.0.0.1:8080'"
<< std::endl;
if (argc != 3) {
std::cout << "Client takes 2 mandatory arguments as follows: " << std::endl;
std::cout << "./client CLIENT_ID SERVER_URL" << std::endl;
std::cout << "Example: ./flwr_client 0 '127.0.0.1:8080'" << std::endl;
return 0;
}

Expand Down Expand Up @@ -45,15 +38,8 @@ int main(int argc, char **argv) {
// Define a server address
std::string server_add = SERVER_URL;

if (argc == 4 && std::string(argv[3]) == "rere") {
std::cout << "Starting rere client" << std::endl;
// Start rere client
start::start_rere_client(server_add, &client);
} else {
std::cout << "Starting bidi client" << std::endl;
// Start bidi client
start::start_client(server_add, &client);
}
std::cout << "Starting rere client" << std::endl;
start::start_client(server_add, &client);

return 0;
}
1 change: 1 addition & 0 deletions src/cc/flwr/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
build/
.clangd
*.bak
2 changes: 2 additions & 0 deletions src/cc/flwr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ GENERATE_AND_COPY(transport)
GENERATE_AND_COPY(node)
GENERATE_AND_COPY(task)
GENERATE_AND_COPY(fleet)
GENERATE_AND_COPY(error)
GENERATE_AND_COPY(recordset)

add_library(flwr_grpc_proto STATIC ${ALL_PROTO_FILES})

Expand Down
30 changes: 30 additions & 0 deletions src/cc/flwr/include/communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef COMMUNICATOR_H
#define COMMUNICATOR_H

#include "flwr/proto/fleet.pb.h"
#include <chrono>
#include <optional>

class Communicator {
public:
virtual bool send_create_node(flwr::proto::CreateNodeRequest request,
flwr::proto::CreateNodeResponse *response) = 0;

virtual bool send_delete_node(flwr::proto::DeleteNodeRequest request,
flwr::proto::DeleteNodeResponse *response) = 0;

virtual bool
send_pull_task_ins(flwr::proto::PullTaskInsRequest request,
flwr::proto::PullTaskInsResponse *response) = 0;

virtual bool
send_push_task_res(flwr::proto::PushTaskResRequest request,
flwr::proto::PushTaskResResponse *response) = 0;
};

void create_node(Communicator *communicator);
void delete_node(Communicator *communicator);
void send(Communicator *communicator, flwr::proto::TaskRes task_res);
std::optional<flwr::proto::TaskIns> receive(Communicator *communicator);

#endif
27 changes: 27 additions & 0 deletions src/cc/flwr/include/flwr/proto/error.grpc.pb.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions src/cc/flwr/include/flwr/proto/error.grpc.pb.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 58c47b5

Please sign in to comment.