Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/adap/flower into cpp-better…
Browse files Browse the repository at this point in the history
…-comms
  • Loading branch information
charlesbvll committed Mar 28, 2024
2 parents 8057db1 + 531e0e3 commit 08b4baf
Show file tree
Hide file tree
Showing 53 changed files with 833 additions and 282 deletions.
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: local
hooks:
- id: format-code
name: Format Code
entry: ./dev/format.sh
language: script
# Ensures the script runs from the repository root:
pass_filenames: false
stages: [commit]

- id: run-tests
name: Run Tests
entry: ./dev/test.sh
language: script
# Ensures the script runs from the repository root:
pass_filenames: false
stages: [commit]
2 changes: 1 addition & 1 deletion datasets/doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
author = "The Flower Authors"

# The full version, including alpha/beta/rc tags
release = "0.0.2"
release = "0.1.0"


# -- General configuration ---------------------------------------------------
Expand Down
27 changes: 27 additions & 0 deletions doc/source/contributor-tutorial-get-started-as-a-contributor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,33 @@ Run Linters and Tests

$ ./dev/test.sh

Add a pre-commit hook
~~~~~~~~~~~~~~~~~~~~~

Developers may integrate a pre-commit hook into their workflow utilizing the `pre-commit <https://pre-commit.com/#install>`_ library. The pre-commit hook is configured to execute two primary operations: ``./dev/format.sh`` and ``./dev/test.sh`` scripts.

There are multiple ways developers can use this:

1. Install the pre-commit hook to your local git directory by simply running:

::
$ pre-commit install

- Each ``git commit`` will trigger the execution of formatting and linting/test scripts.
- If in a hurry, bypass the hook using ``--no-verify`` with the ``git commit`` command.
::
$ git commit --no-verify -m "Add new feature"
2. For developers who prefer not to install the hook permanently, it is possible to execute a one-time check prior to committing changes by using the following command:

::

$ pre-commit run --all-files
This executes the formatting and linting checks/tests on all the files without modifying the default behavior of ``git commit``.

Run Github Actions (CI) locally
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 3 additions & 3 deletions examples/app-pytorch/client_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message:
@app.train()
def train(msg: Message, ctx: Context):
print("`train` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.evaluate()
def eval(msg: Message, ctx: Context):
print("`evaluate` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.query()
def query(msg: Message, ctx: Context):
print("`query` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)
36 changes: 23 additions & 13 deletions examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Message,
MessageType,
Metrics,
DEFAULT_TTL,
)
from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres
from flwr.server import Driver, History
Expand Down Expand Up @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand All @@ -102,15 +103,19 @@ def main(driver: Driver, context: Context) -> None:
all_replies: List[Message] = []
while True:
replies = driver.pull_messages(message_ids=message_ids)
print(f"Got {len(replies)} results")
for res in replies:
print(f"Got 1 {'result' if res.has_content() else 'error'}")
all_replies += replies
if len(all_replies) == len(message_ids):
break
print("Pulling messages...")
time.sleep(3)

# Collect correct results
# Filter correct results
all_fitres = [
recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies
recordset_to_fitres(msg.content, keep_input=True)
for msg in all_replies
if msg.has_content()
]
print(f"Received {len(all_fitres)} results")

Expand All @@ -127,16 +132,21 @@ def main(driver: Driver, context: Context) -> None:
)
metrics_results.append((fitres.num_examples, fitres.metrics))

# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated
if len(weights_results) > 0:
# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated

# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
else:
print(
f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..."
)

# Slow down the start of the next round
time.sleep(sleep_time)
Expand Down
12 changes: 10 additions & 2 deletions examples/app-pytorch/server_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import time

import flwr as fl
from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet
from flwr.common import (
Context,
NDArrays,
Message,
MessageType,
Metrics,
RecordSet,
DEFAULT_TTL,
)
from flwr.server import Driver


Expand All @@ -30,7 +38,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand Down
1 change: 1 addition & 0 deletions examples/llm-flowertune/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ bitsandbytes==0.41.3
scipy==1.11.2
peft==0.4.0
fschat[model_worker,webui]==0.2.35
transformers==4.38.1
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ check-wheel-contents = "==0.4.0"
GitPython = "==3.1.32"
PyGithub = "==2.1.1"
licensecheck = "==2024"
pre-commit = "==3.5.0"

[tool.isort]
line_length = 88
Expand Down
8 changes: 8 additions & 0 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import "flwr/proto/task.proto";
service Fleet {
rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {}
rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) {}
rpc Ping(PingRequest) returns (PingResponse) {}

// Retrieve one or more tasks, if possible
//
Expand All @@ -43,6 +44,13 @@ message CreateNodeResponse { Node node = 1; }
message DeleteNodeRequest { Node node = 1; }
message DeleteNodeResponse {}

// Ping messages
message PingRequest {
Node node = 1;
double ping_interval = 2;
}
message PingResponse { bool success = 1; }

// PullTaskIns messages
message PullTaskInsRequest {
Node node = 1;
Expand Down
13 changes: 7 additions & 6 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ import "flwr/proto/error.proto";
message Task {
Node producer = 1;
Node consumer = 2;
string created_at = 3;
double created_at = 3;
string delivered_at = 4;
string ttl = 5;
repeated string ancestry = 6;
string task_type = 7;
RecordSet recordset = 8;
Error error = 9;
double pushed_at = 5;
double ttl = 6;
repeated string ancestry = 7;
string task_type = 8;
RecordSet recordset = 9;
Error error = 10;
}

message TaskIns {
Expand Down
59 changes: 35 additions & 24 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# ==============================================================================
"""Flower client app."""


import argparse
import sys
import time
from logging import DEBUG, INFO, WARN
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -38,6 +37,7 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.message import Error
from flwr.common.object_ref import load_app, validate
from flwr.common.retry_invoker import RetryInvoker, exponential

Expand Down Expand Up @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp:
# Retrieve context for this run
context = node_state.retrieve_context(run_id=message.metadata.run_id)

# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()
# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
error=Error(code=0, reason="Unknown")
)

# Handle task message
out_message = client_app(message=message, context=context)
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()

# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
reply_message = client_app(message=message, context=context)
# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp raised an exception", exc_info=ex)

# Legacy grpc-bidi
if transport in ["grpc-bidi", None]:
# Raise exception, crash process
raise ex

# Don't update/change NodeState

# Create error message
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
reply_message = message.create_error_reply(
error=Error(code=0, reason=reason)
)

# Send
send(out_message)
log(
INFO,
"[RUN %s, ROUND %s]",
out_message.metadata.run_id,
out_message.metadata.group_id,
)
log(
INFO,
"Sent: %s reply to message %s",
out_message.metadata.message_type,
message.metadata.message_id,
)
send(reply_message)
log(INFO, "Sent reply")

# Unregister node
if delete_node is not None:
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def train(message: Message, context: Context) -> Message:
>>> print("ClientApp training running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def evaluate(message: Message, context: Context) -> Message:
>>> print("ClientApp evaluation running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def query(message: Message, context: Context) -> Message:
>>> print("ClientApp query running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError:
>>> print("ClientApp {fn_name} running")
>>> # Create and return an echo reply message
>>> return message.create_reply(
>>> content=message.content(), ttl=""
>>> content=message.content()
>>> )
""",
)
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.common import (
DEFAULT_TTL,
GRPC_MAX_MESSAGE_LENGTH,
ConfigsRecord,
Message,
Expand Down Expand Up @@ -180,7 +181,7 @@ def receive() -> Message:
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=message_type,
),
content=recordset,
Expand Down
Loading

0 comments on commit 08b4baf

Please sign in to comment.