diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..ad6cb69f305 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: local + hooks: + - id: format-code + name: Format Code + entry: ./dev/format.sh + language: script + # Ensures the script runs from the repository root: + pass_filenames: false + stages: [commit] + + - id: run-tests + name: Run Tests + entry: ./dev/test.sh + language: script + # Ensures the script runs from the repository root: + pass_filenames: false + stages: [commit] diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index e5c61b5559c..755147bc9e1 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -38,7 +38,7 @@ author = "The Flower Authors" # The full version, including alpha/beta/rc tags -release = "0.0.2" +release = "0.1.0" # -- General configuration --------------------------------------------------- diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index 9136fea96bf..43f9739987a 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -102,6 +102,33 @@ Run Linters and Tests $ ./dev/test.sh +Add a pre-commit hook +~~~~~~~~~~~~~~~~~~~~~ + +Developers may integrate a pre-commit hook into their workflow utilizing the `pre-commit `_ library. The pre-commit hook is configured to execute two primary operations: ``./dev/format.sh`` and ``./dev/test.sh`` scripts. + +There are multiple ways developers can use this: + +1. Install the pre-commit hook to your local git directory by simply running: + + :: + + $ pre-commit install + + - Each ``git commit`` will trigger the execution of formatting and linting/test scripts. + - If in a hurry, bypass the hook using ``--no-verify`` with the ``git commit`` command. 
+ :: + + $ git commit --no-verify -m "Add new feature" + +2. For developers who prefer not to install the hook permanently, it is possible to execute a one-time check prior to committing changes by using the following command: + + :: + + $ pre-commit run --all-files + + This executes the formatting and linting checks/tests on all the files without modifying the default behavior of ``git commit``. + Run Github Actions (CI) locally ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py index feea1ee658f..19268ff84ba 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message: @app.train() def train(msg: Message, ctx: Context): print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) @app.evaluate() def eval(msg: Message, ctx: Context): print("`evaluate` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) @app.query() def query(msg: Message, ctx: Context): print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index 0c2851e2afe..67c1bce99c5 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -13,6 +13,7 @@ Message, MessageType, Metrics, + DEFAULT_TTL, ) from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres from flwr.server import Driver, History @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) @@ -102,15 
+103,19 @@ def main(driver: Driver, context: Context) -> None: all_replies: List[Message] = [] while True: replies = driver.pull_messages(message_ids=message_ids) - print(f"Got {len(replies)} results") + for res in replies: + print(f"Got 1 {'result' if res.has_content() else 'error'}") all_replies += replies if len(all_replies) == len(message_ids): break + print("Pulling messages...") time.sleep(3) - # Collect correct results + # Filter correct results all_fitres = [ - recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + recordset_to_fitres(msg.content, keep_input=True) + for msg in all_replies + if msg.has_content() ] print(f"Received {len(all_fitres)} results") @@ -127,16 +132,21 @@ def main(driver: Driver, context: Context) -> None: ) metrics_results.append((fitres.num_examples, fitres.metrics)) - # Aggregate parameters (FedAvg) - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated + if len(weights_results) > 0: + # Aggregate parameters (FedAvg) + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + else: + print( + f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..." 
+ ) # Slow down the start of the next round time.sleep(sleep_time) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py index 560babac1b9..7ab79a4a04c 100644 --- a/examples/app-pytorch/server_low_level.py +++ b/examples/app-pytorch/server_low_level.py @@ -3,7 +3,15 @@ import time import flwr as fl -from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.common import ( + Context, + NDArrays, + Message, + MessageType, + Metrics, + RecordSet, + DEFAULT_TTL, +) from flwr.server import Driver @@ -30,7 +38,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt index c7ff57b403f..196531c99b9 100644 --- a/examples/llm-flowertune/requirements.txt +++ b/examples/llm-flowertune/requirements.txt @@ -6,3 +6,4 @@ bitsandbytes==0.41.3 scipy==1.11.2 peft==0.4.0 fschat[model_worker,webui]==0.2.35 +transformers==4.38.1 diff --git a/pyproject.toml b/pyproject.toml index dc8b293bc88..3c211e9cf8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ check-wheel-contents = "==0.4.0" GitPython = "==3.1.32" PyGithub = "==2.1.1" licensecheck = "==2024" +pre-commit = "==3.5.0" [tool.isort] line_length = 88 diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index c900a3b1148..fa65f3ee9fe 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -23,6 +23,7 @@ import "flwr/proto/task.proto"; service Fleet { rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {} rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) {} + rpc Ping(PingRequest) returns (PingResponse) {} // Retrieve one or more tasks, if possible // @@ -43,6 +44,13 @@ message CreateNodeResponse { Node node = 1; } message 
DeleteNodeRequest { Node node = 1; } message DeleteNodeResponse {} +// Ping messages +message PingRequest { + Node node = 1; + double ping_interval = 2; +} +message PingResponse { bool success = 1; } + // PullTaskIns messages message PullTaskInsRequest { Node node = 1; diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 423df76f133..cf77d110aca 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -25,13 +25,14 @@ import "flwr/proto/error.proto"; message Task { Node producer = 1; Node consumer = 2; - string created_at = 3; + double created_at = 3; string delivered_at = 4; - string ttl = 5; - repeated string ancestry = 6; - string task_type = 7; - RecordSet recordset = 8; - Error error = 9; + double pushed_at = 5; + double ttl = 6; + repeated string ancestry = 7; + string task_type = 8; + RecordSet recordset = 9; + Error error = 10; } message TaskIns { diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index c8287afc0fd..d4bd8e2e39e 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -14,11 +14,10 @@ # ============================================================================== """Flower client app.""" - import argparse import sys import time -from logging import DEBUG, INFO, WARN +from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path from typing import Callable, ContextManager, Optional, Tuple, Type, Union @@ -38,6 +37,7 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.message import Error from flwr.common.object_ref import load_app, validate from flwr.common.retry_invoker import RetryInvoker, exponential @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp: # Retrieve context for this run context = node_state.retrieve_context(run_id=message.metadata.run_id) - # Load ClientApp instance - client_app: ClientApp = 
load_client_app_fn() + # Create an error reply message that will never be used to prevent + # the used-before-assignment linting error + reply_message = message.create_error_reply( + error=Error(code=0, reason="Unknown") + ) - # Handle task message - out_message = client_app(message=message, context=context) + # Handle app loading and task message + try: + # Load ClientApp instance + client_app: ClientApp = load_client_app_fn() - # Update node state - node_state.update_context( - run_id=message.metadata.run_id, - context=context, - ) + reply_message = client_app(message=message, context=context) + # Update node state + node_state.update_context( + run_id=message.metadata.run_id, + context=context, + ) + except Exception as ex: # pylint: disable=broad-exception-caught + log(ERROR, "ClientApp raised an exception", exc_info=ex) + + # Legacy grpc-bidi + if transport in ["grpc-bidi", None]: + # Raise exception, crash process + raise ex + + # Don't update/change NodeState + + # Create error message + # Reason example: ":<'division by zero'>" + reason = str(type(ex)) + ":<'" + str(ex) + "'>" + reply_message = message.create_error_reply( + error=Error(code=0, reason=reason) + ) # Send - send(out_message) - log( - INFO, - "[RUN %s, ROUND %s]", - out_message.metadata.run_id, - out_message.metadata.group_id, - ) - log( - INFO, - "Sent: %s reply to message %s", - out_message.metadata.message_type, - message.metadata.message_id, - ) + send(reply_message) + log(INFO, "Sent reply") # Unregister node if delete_node is not None: diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index ad7a0132699..0b56219807c 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def train(message: Message, context: Context) -> Message: >>> print("ClientApp training running") >>> # Create and return an echo reply message - >>> return 
message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def evaluate(message: Message, context: Context) -> Message: >>> print("ClientApp evaluation running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def query(message: Message, context: Context) -> Message: >>> print("ClientApp query running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError: >>> print("ClientApp {fn_name} running") >>> # Create and return an echo reply message >>> return message.create_reply( - >>> content=message.content(), ttl="" + >>> content=message.content() >>> ) """, ) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 163a58542c9..4431b53d259 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -23,6 +23,7 @@ from typing import Callable, Iterator, Optional, Tuple, Union, cast from flwr.common import ( + DEFAULT_TTL, GRPC_MAX_MESSAGE_LENGTH, ConfigsRecord, Message, @@ -180,7 +181,7 @@ def receive() -> Message: dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=message_type, ), content=recordset, diff --git 
a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index b7737f511a2..061e7d4377a 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,7 +23,7 @@ import grpc -from flwr.common import ConfigsRecord, Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, ConfigsRecord, Message, Metadata, RecordSet from flwr.common import recordset_compat as compat from flwr.common.constant import MessageTypeLegacy from flwr.common.retry_invoker import RetryInvoker, exponential @@ -50,7 +50,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=compat.getpropertiesres_to_recordset( @@ -65,7 +65,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type="reconnect", ), content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 9a5d70b1ac4..e5acbe0cc9d 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: reason = cast(int, disconnect_msg.disconnect_res.reason) recordset = RecordSet() recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) - out_message = message.create_reply(recordset, ttl="") + out_message = message.create_reply(recordset) # Return TaskRes and sleep duration return out_message, sleep_duration @@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype( raise ValueError(f"Invalid message type: {message_type}") # Return Message - return message.create_reply(out_recordset, ttl="") + return message.create_reply(out_recordset) def _reconnect( @@ -172,6 +172,7 @@ def 
validate_out_message(out_message: Message, in_message_metadata: Metadata) -> and out_meta.reply_to_message == in_meta.message_id and out_meta.group_id == in_meta.group_id and out_meta.message_type == in_meta.message_type + and out_meta.created_at > in_meta.created_at ): return True return False diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index eaf16f7dc99..5244951c8a4 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -15,6 +15,7 @@ """Client-side message handler tests.""" +import time import unittest import uuid from copy import copy @@ -23,6 +24,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( + DEFAULT_TTL, Code, Context, EvaluateIns, @@ -131,7 +133,7 @@ def test_client_without_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -161,14 +163,25 @@ def test_client_without_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply() message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, ) assert actual_msg.content == expected_msg.content - assert actual_msg.metadata == expected_msg.metadata + # metadata.created_at will differ so let's exclude it from checks + attrs = actual_msg.metadata.__annotations__ + attrs_keys = list(attrs.keys()) + attrs_keys.remove("_created_at") + # metadata.created_at will differ so let's exclude it from checks + for attr in attrs_keys: + assert getattr(actual_msg.metadata, attr) == getattr( + expected_msg.metadata, attr + ) + + # Ensure the message created last has a higher timestamp + assert actual_msg.metadata.created_at < 
expected_msg.metadata.created_at def test_client_with_get_properties() -> None: @@ -184,7 +197,7 @@ def test_client_with_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -214,14 +227,24 @@ def test_client_with_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply() message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, ) assert actual_msg.content == expected_msg.content - assert actual_msg.metadata == expected_msg.metadata + attrs = actual_msg.metadata.__annotations__ + attrs_keys = list(attrs.keys()) + attrs_keys.remove("_created_at") + # metadata.created_at will differ so let's exclude it from checks + for attr in attrs_keys: + assert getattr(actual_msg.metadata, attr) == getattr( + expected_msg.metadata, attr + ) + + # Ensure the message created last has a higher timestamp + assert actual_msg.metadata.created_at < expected_msg.metadata.created_at class TestMessageValidation(unittest.TestCase): @@ -237,9 +260,14 @@ def setUp(self) -> None: dst_node_id=20, reply_to_message="", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) + # We need to set created_at in this way + # since this `self.in_metadata` is used for tests + # without it ever being part of a Message + self.in_metadata.created_at = time.time() + self.valid_out_metadata = Metadata( run_id=123, message_id="", @@ -247,7 +275,7 @@ def setUp(self) -> None: dst_node_id=10, reply_to_message="qwerty", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) self.common_content = RecordSet() @@ -280,6 +308,10 @@ def test_invalid_message_run_id(self) -> None: value = 999 elif isinstance(value, str): value = "999" + elif isinstance(value, float): + if attr == "_created_at": + # make it be in 1h the 
past + value = value - 3600 setattr(invalid_metadata, attr, value) # Add to list invalid_metadata_list.append(invalid_metadata) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 989d5f6e136..5b196ad8432 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -187,7 +187,7 @@ def secaggplus_mod( # Return message out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) - return msg.create_reply(out_content, ttl="") + return msg.create_reply(out_content) def check_stage(current_stage: str, configs: ConfigsRecord) -> None: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index db5ed67c02a..36844a2983a 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -19,7 +19,14 @@ from typing import Callable, Dict, List from flwr.client.mod import make_ffn -from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet +from flwr.common import ( + DEFAULT_TTL, + ConfigsRecord, + Context, + Message, + Metadata, + RecordSet, +) from flwr.common.constant import MessageType from flwr.common.secure_aggregation.secaggplus_constants import ( RECORD_KEY_CONFIGS, @@ -38,7 +45,7 @@ def get_test_handler( """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: - return _msg.create_reply(RecordSet(), ttl="") + return _msg.create_reply(RecordSet()) app = make_ffn(empty_ffn, [secaggplus_mod]) @@ -51,7 +58,7 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: dst_node_id=123, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageType.TRAIN, ), content=RecordSet( diff --git a/src/py/flwr/client/mod/utils_test.py 
b/src/py/flwr/client/mod/utils_test.py index e588b8b53b3..4676a2c02c4 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -20,6 +20,7 @@ from flwr.client.typing import ClientAppCallable, Mod from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, Message, @@ -84,7 +85,7 @@ def _get_dummy_flower_message() -> Message: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type="mock", ), ) diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 9f9ff7ebc68..2fb98c82dd6 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -22,6 +22,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH from .logger import configure as configure from .logger import log as log +from .message import DEFAULT_TTL from .message import Error as Error from .message import Message as Message from .message import Metadata as Metadata @@ -87,6 +88,7 @@ "Message", "MessageType", "MessageTypeLegacy", + "DEFAULT_TTL", "Metadata", "Metrics", "MetricsAggregationFn", diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 88cf750f1a9..7707f3c72de 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -16,10 +16,13 @@ from __future__ import annotations +import time from dataclasses import dataclass from .record import RecordSet +DEFAULT_TTL = 3600 + @dataclass class Metadata: # pylint: disable=too-many-instance-attributes @@ -40,8 +43,8 @@ class Metadata: # pylint: disable=too-many-instance-attributes group_id : str An identifier for grouping messages. In some settings, this is used as the FL round. - ttl : str - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. message_type : str A string that encodes the action to be executed on the receiving end. 
@@ -57,9 +60,10 @@ class Metadata: # pylint: disable=too-many-instance-attributes _dst_node_id: int _reply_to_message: str _group_id: str - _ttl: str + _ttl: float _message_type: str _partition_id: int | None + _created_at: float # Unix timestamp (in seconds) to be set upon message creation def __init__( # pylint: disable=too-many-arguments self, @@ -69,7 +73,7 @@ def __init__( # pylint: disable=too-many-arguments dst_node_id: int, reply_to_message: str, group_id: str, - ttl: str, + ttl: float, message_type: str, partition_id: int | None = None, ) -> None: @@ -124,12 +128,22 @@ def group_id(self, value: str) -> None: self._group_id = value @property - def ttl(self) -> str: + def created_at(self) -> float: + """Unix timestamp when the message was created.""" + return self._created_at + + @created_at.setter + def created_at(self, value: float) -> None: + """Set creation timestamp for this messages.""" + self._created_at = value + + @property + def ttl(self) -> float: """Time-to-live for this message.""" return self._ttl @ttl.setter - def ttl(self, value: str) -> None: + def ttl(self, value: float) -> None: """Set ttl.""" self._ttl = value @@ -212,6 +226,9 @@ def __init__( ) -> None: self._metadata = metadata + # Set message creation timestamp + self._metadata.created_at = time.time() + if not (content is None) ^ (error is None): raise ValueError("Either `content` or `error` must be set, but not both.") @@ -266,7 +283,7 @@ def has_error(self) -> bool: """Return True if message has an error, else False.""" return self._error is not None - def _create_reply_metadata(self, ttl: str) -> Metadata: + def _create_reply_metadata(self, ttl: float) -> Metadata: """Construct metadata for a reply message.""" return Metadata( run_id=self.metadata.run_id, @@ -280,25 +297,36 @@ def _create_reply_metadata(self, ttl: str) -> Metadata: partition_id=self.metadata.partition_id, ) - def create_error_reply( - self, - error: Error, - ttl: str, - ) -> Message: + def create_error_reply(self, 
error: Error, ttl: float | None = None) -> Message: """Construct a reply message indicating an error happened. Parameters ---------- error : Error The error that was encountered. - ttl : str - Time-to-live for this message. + ttl : Optional[float] (default: None) + Time-to-live for this message in seconds. If unset, it will be set based + on the remaining time for the received message before it expires. This + follows the equation: + + ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) """ + # If no TTL passed, use default for message creation (will update after + # message creation) + ttl_ = DEFAULT_TTL if ttl is None else ttl # Create reply with error - message = Message(metadata=self._create_reply_metadata(ttl), error=error) + message = Message(metadata=self._create_reply_metadata(ttl_), error=error) + + if ttl is None: + # Set TTL equal to the remaining time for the received message to expire + ttl = self.metadata.ttl - ( + message.metadata.created_at - self.metadata.created_at + ) + message.metadata.ttl = ttl + return message - def create_reply(self, content: RecordSet, ttl: str) -> Message: + def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: """Create a reply to this message with specified content and TTL. The method generates a new `Message` as a reply to this message. @@ -309,15 +337,32 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: ---------- content : RecordSet The content for the reply message. - ttl : str - Time-to-live for this message. + ttl : Optional[float] (default: None) + Time-to-live for this message in seconds. If unset, it will be set based + on the remaining time for the received message before it expires. This + follows the equation: + + ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) Returns ------- Message A new `Message` instance representing the reply. 
""" - return Message( - metadata=self._create_reply_metadata(ttl), + # If no TTL passed, use default for message creation (will update after + # message creation) + ttl_ = DEFAULT_TTL if ttl is None else ttl + + message = Message( + metadata=self._create_reply_metadata(ttl_), content=content, ) + + if ttl is None: + # Set TTL equal to the remaining time for the received message to expire + ttl = self.metadata.ttl - ( + message.metadata.created_at - self.metadata.created_at + ) + message.metadata.ttl = ttl + + return message diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index ba628bb3235..1a5da051735 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -14,9 +14,9 @@ # ============================================================================== """Message tests.""" - +import time from contextlib import ExitStack -from typing import Any, Callable +from typing import Any, Callable, Optional import pytest @@ -62,24 +62,32 @@ def test_message_creation( if context: stack.enter_context(context) - _ = Message( + current_time = time.time() + message = Message( metadata=metadata, content=None if content_fn is None else content_fn(maker), error=None if error_fn is None else error_fn(0), ) + assert message.metadata.created_at > current_time + assert message.metadata.created_at < time.time() + -def create_message_with_content() -> Message: +def create_message_with_content(ttl: Optional[float] = None) -> Message: """Create a Message with content.""" maker = RecordMaker(state=2) metadata = maker.metadata() + if ttl: + metadata.ttl = ttl return Message(metadata=metadata, content=RecordSet()) -def create_message_with_error() -> Message: +def create_message_with_error(ttl: Optional[float] = None) -> Message: """Create a Message with error.""" maker = RecordMaker(state=2) metadata = maker.metadata() + if ttl: + metadata.ttl = ttl return Message(metadata=metadata, error=Error(code=1)) @@ -107,3 +115,45 @@ 
def test_altering_message( message.error = Error(code=123) if message.has_error(): message.content = RecordSet() + + +@pytest.mark.parametrize( + "message_creation_fn,ttl,reply_ttl", + [ + (create_message_with_content, 1e6, None), + (create_message_with_error, 1e6, None), + (create_message_with_content, 1e6, 3600), + (create_message_with_error, 1e6, 3600), + ], +) +def test_create_reply( + message_creation_fn: Callable[ + [float], + Message, + ], + ttl: float, + reply_ttl: Optional[float], +) -> None: + """Test reply creation from message.""" + message: Message = message_creation_fn(ttl) + + time.sleep(0.1) + + if message.has_error(): + dummy_error = Error(code=0, reason="it crashed") + reply_message = message.create_error_reply(dummy_error, ttl=reply_ttl) + else: + reply_message = message.create_reply(content=RecordSet(), ttl=reply_ttl) + + # Ensure reply has a higher timestamp + assert message.metadata.created_at < reply_message.metadata.created_at + if reply_ttl: + # Ensure the TTL is the one specified upon reply creation + assert reply_message.metadata.ttl == reply_ttl + else: + # Ensure reply ttl is lower (since it uses remaining time left) + assert message.metadata.ttl > reply_message.metadata.ttl + + assert message.metadata.src_node_id == reply_message.metadata.dst_node_id + assert message.metadata.dst_node_id == reply_message.metadata.src_node_id + assert reply_message.metadata.reply_to_message == message.metadata.message_id diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 6c7a077d2f9..84932b806af 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -575,6 +575,7 @@ def message_to_taskins(message: Message) -> TaskIns: task=Task( producer=Node(node_id=0, anonymous=True), # Assume driver node consumer=Node(node_id=md.dst_node_id, anonymous=False), + created_at=md.created_at, ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, @@ -601,7 +602,7 @@ def
message_from_taskins(taskins: TaskIns) -> Message: ) # Construct Message - return Message( + message = Message( metadata=metadata, content=( recordset_from_proto(taskins.task.recordset) @@ -614,6 +615,8 @@ def message_from_taskins(taskins: TaskIns) -> Message: else None ), ) + message.metadata.created_at = taskins.task.created_at + return message def message_to_taskres(message: Message) -> TaskRes: @@ -626,6 +629,7 @@ def message_to_taskres(message: Message) -> TaskRes: task=Task( producer=Node(node_id=md.src_node_id, anonymous=False), consumer=Node(node_id=0, anonymous=True), # Assume driver node + created_at=md.created_at, ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, @@ -652,7 +656,7 @@ def message_from_taskres(taskres: TaskRes) -> Message: ) # Construct the Message - return Message( + message = Message( metadata=metadata, content=( recordset_from_proto(taskres.task.recordset) @@ -665,3 +669,5 @@ def message_from_taskres(taskres: TaskRes) -> Message: else None ), ) + message.metadata.created_at = taskres.task.created_at + return message diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 8596e5d2f33..fc12ce95328 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: src_node_id=self.rng.randint(0, 1 << 63), dst_node_id=self.rng.randint(0, 1 << 63), reply_to_message=self.get_str(64), - ttl=self.get_str(10), + ttl=self.rng.randint(1, 1 << 30), message_type=self.get_str(10), ) diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index e8443c296f0..546987f1c80 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x02\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 
\x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -33,18 +33,22 @@ _globals['_DELETENODEREQUEST']._serialized_end=210 _globals['_DELETENODERESPONSE']._serialized_start=212 _globals['_DELETENODERESPONSE']._serialized_end=232 - _globals['_PULLTASKINSREQUEST']._serialized_start=234 - _globals['_PULLTASKINSREQUEST']._serialized_end=304 - _globals['_PULLTASKINSRESPONSE']._serialized_start=306 - _globals['_PULLTASKINSRESPONSE']._serialized_end=413 - 
_globals['_PUSHTASKRESREQUEST']._serialized_start=415 - _globals['_PUSHTASKRESREQUEST']._serialized_end=479 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=482 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=656 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=610 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=656 - _globals['_RECONNECT']._serialized_start=658 - _globals['_RECONNECT']._serialized_end=688 - _globals['_FLEET']._serialized_start=691 - _globals['_FLEET']._serialized_end=1020 + _globals['_PINGREQUEST']._serialized_start=234 + _globals['_PINGREQUEST']._serialized_end=302 + _globals['_PINGRESPONSE']._serialized_start=304 + _globals['_PINGRESPONSE']._serialized_end=335 + _globals['_PULLTASKINSREQUEST']._serialized_start=337 + _globals['_PULLTASKINSREQUEST']._serialized_end=407 + _globals['_PULLTASKINSRESPONSE']._serialized_start=409 + _globals['_PULLTASKINSRESPONSE']._serialized_end=516 + _globals['_PUSHTASKRESREQUEST']._serialized_start=518 + _globals['_PUSHTASKRESREQUEST']._serialized_end=582 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=585 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=759 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=713 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=759 + _globals['_RECONNECT']._serialized_start=761 + _globals['_RECONNECT']._serialized_end=791 + _globals['_FLEET']._serialized_start=794 + _globals['_FLEET']._serialized_end=1184 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index 86bc358858d..e5c5b736646 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -53,6 +53,34 @@ class DeleteNodeResponse(google.protobuf.message.Message): ) -> None: ... 
global___DeleteNodeResponse = DeleteNodeResponse +class PingRequest(google.protobuf.message.Message): + """Ping messages""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int + PING_INTERVAL_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... + ping_interval: builtins.float + def __init__(self, + *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., + ping_interval: builtins.float = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","ping_interval",b"ping_interval"]) -> None: ... +global___PingRequest = PingRequest + +class PingResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__(self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ... 
+global___PingResponse = PingResponse + class PullTaskInsRequest(google.protobuf.message.Message): """PullTaskIns messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/proto/fleet_pb2_grpc.py b/src/py/flwr/proto/fleet_pb2_grpc.py index 2b53ec43e85..c31a4ec73f0 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.py +++ b/src/py/flwr/proto/fleet_pb2_grpc.py @@ -24,6 +24,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.FromString, ) + self.Ping = channel.unary_unary( + '/flwr.proto.Fleet/Ping', + request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString, + ) self.PullTaskIns = channel.unary_unary( '/flwr.proto.Fleet/PullTaskIns', request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString, @@ -51,6 +56,12 @@ def DeleteNode(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def Ping(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PullTaskIns(self, request, context): """Retrieve one or more tasks, if possible @@ -82,6 +93,11 @@ def add_FleetServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, ), + 'Ping': grpc.unary_unary_rpc_method_handler( + servicer.Ping, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString, + ), 'PullTaskIns': 
grpc.unary_unary_rpc_method_handler( servicer.PullTaskIns, request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, @@ -136,6 +152,23 @@ def DeleteNode(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def Ping(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/Ping', + flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString, + flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def PullTaskIns(request, target, diff --git a/src/py/flwr/proto/fleet_pb2_grpc.pyi b/src/py/flwr/proto/fleet_pb2_grpc.pyi index cfa83f73743..33ba9440793 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.pyi +++ b/src/py/flwr/proto/fleet_pb2_grpc.pyi @@ -16,6 +16,10 @@ class FleetStub: flwr.proto.fleet_pb2.DeleteNodeRequest, flwr.proto.fleet_pb2.DeleteNodeResponse] + Ping: grpc.UnaryUnaryMultiCallable[ + flwr.proto.fleet_pb2.PingRequest, + flwr.proto.fleet_pb2.PingResponse] + PullTaskIns: grpc.UnaryUnaryMultiCallable[ flwr.proto.fleet_pb2.PullTaskInsRequest, flwr.proto.fleet_pb2.PullTaskInsResponse] @@ -46,6 +50,12 @@ class FleetServicer(metaclass=abc.ABCMeta): context: grpc.ServicerContext, ) -> flwr.proto.fleet_pb2.DeleteNodeResponse: ... + @abc.abstractmethod + def Ping(self, + request: flwr.proto.fleet_pb2.PingRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.fleet_pb2.PingResponse: ... 
+ @abc.abstractmethod def PullTaskIns(self, request: flwr.proto.fleet_pb2.PullTaskInsRequest, diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 4d5f863e88d..5f6e9e7be58 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 
\x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -26,9 +26,9 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals['_TASK']._serialized_start=141 - _globals['_TASK']._serialized_end=387 - _globals['_TASKINS']._serialized_start=389 - _globals['_TASKINS']._serialized_end=481 - _globals['_TASKRES']._serialized_start=483 - _globals['_TASKRES']._serialized_end=575 + _globals['_TASK']._serialized_end=406 + _globals['_TASKINS']._serialized_start=408 + _globals['_TASKINS']._serialized_end=500 + _globals['_TASKRES']._serialized_start=502 + _globals['_TASKRES']._serialized_end=594 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index b9c10139cfb..455791ac9e6 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -20,6 +20,7 @@ class Task(google.protobuf.message.Message): CONSUMER_FIELD_NUMBER: builtins.int CREATED_AT_FIELD_NUMBER: builtins.int DELIVERED_AT_FIELD_NUMBER: builtins.int + PUSHED_AT_FIELD_NUMBER: builtins.int TTL_FIELD_NUMBER: builtins.int ANCESTRY_FIELD_NUMBER: builtins.int TASK_TYPE_FIELD_NUMBER: builtins.int @@ -29,9 +30,10 @@ class Task(google.protobuf.message.Message): def producer(self) -> flwr.proto.node_pb2.Node: ... 
@property def consumer(self) -> flwr.proto.node_pb2.Node: ... - created_at: typing.Text + created_at: builtins.float delivered_at: typing.Text - ttl: typing.Text + pushed_at: builtins.float + ttl: builtins.float @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... task_type: typing.Text @@ -43,16 +45,17 @@ class Task(google.protobuf.message.Message): *, producer: typing.Optional[flwr.proto.node_pb2.Node] = ..., consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., - created_at: typing.Text = ..., + created_at: builtins.float = ..., delivered_at: typing.Text = ..., - ttl: typing.Text = ..., + pushed_at: builtins.float = ..., + ttl: builtins.float = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., error: typing.Optional[flwr.proto.error_pb2.Error] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","pushed_at",b"pushed_at","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... 
global___Task = Task class TaskIns(google.protobuf.message.Message): diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 84c67149fad..58341c7bb8f 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -19,7 +19,7 @@ from typing import List, Optional from flwr import common -from flwr.common import MessageType, MessageTypeLegacy, RecordSet +from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet from flwr.common import recordset_compat as compat from flwr.common import serde from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 @@ -129,8 +129,16 @@ def _send_receive_recordset( ), task_type=task_type, recordset=serde.recordset_to_proto(recordset), + ttl=DEFAULT_TTL, ), ) + + # This would normally be recorded upon common.Message creation + # but this compatibility stack doesn't create Messages, + # so we need to inject `created_at` manually (needed for + # taskins validation by server.utils.validator) + task_ins.task.created_at = time.time() + push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 task_ins_list=[task_ins] ) @@ -162,8 +170,24 @@ def _send_receive_recordset( ) if len(task_res_list) == 1: task_res = task_res_list[0] + + # This will raise an Exception if task_res carries an `error` + validate_task_res(task_res=task_res) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") time.sleep(SLEEP_TIME) + + +def validate_task_res( + task_res: task_pb2.TaskRes, # pylint: disable=E1101 +) -> None: + """Validate if a TaskRes is empty or not.""" + if not task_res.HasField("task"): + raise ValueError("Invalid TaskRes, field `task` missing") + if task_res.task.HasField("error"): + raise ValueError("Exception during client-side task execution") + if not 
task_res.task.HasField("recordset"): + raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing") diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 3494049c106..57b35fc61a3 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -38,9 +38,14 @@ Properties, Status, ) -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 - -from .driver_client_proxy import DriverClientProxy +from flwr.proto import ( # pylint: disable=E0611 + driver_pb2, + error_pb2, + node_pb2, + recordset_pb2, + task_pb2, +) +from flwr.server.compat.driver_client_proxy import DriverClientProxy, validate_task_res MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -243,3 +248,77 @@ def test_evaluate(self) -> None: # Assert assert 0.0 == evaluate_res.loss assert 0 == evaluate_res.num_examples + + def test_validate_task_res_valid(self) -> None: + """Test valid TaskRes.""" + metrics_record = recordset_pb2.MetricsRecord( # pylint: disable=E1101 + data={ + "loss": recordset_pb2.MetricsRecordValue( # pylint: disable=E1101 + double=1.0 + ) + } + ) + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + recordset=recordset_pb2.RecordSet( # pylint: disable=E1101 + parameters={}, + metrics={"loss": metrics_record}, + configs={}, + ) + ), + ) + + # Execute & assert + try: + validate_task_res(task_res=task_res) + except ValueError: + self.fail() + + def test_validate_task_res_missing_task(self) -> None: + """Test invalid TaskRes (missing task).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + ) + + # Execute & assert + with self.assertRaises(ValueError): + 
validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_recordset(self) -> None: + """Test invalid TaskRes (missing recordset).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task(), # pylint: disable=E1101 + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_content(self) -> None: + """Test invalid TaskRes (missing content).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + error=error_pb2.Error( # pylint: disable=E1101 + code=0, + reason="Some reason", + ) + ), + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 0098e0ce97c..afebd90ea26 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -18,7 +18,7 @@ import time from typing import Iterable, List, Optional, Tuple -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -81,6 +81,7 @@ def _check_message(self, message: Message) -> None: and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" + and message.metadata.ttl > 0 ): raise ValueError(f"Invalid message: {message}") @@ -90,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: str, + ttl: float = DEFAULT_TTL, ) -> Message: """Create a new message 
with specified parameters. @@ -110,10 +111,10 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round. - ttl : str + ttl : float (default: common.DEFAULT_TTL) Time-to-live for the round trip of this message, i.e., the time from sending - this message to receiving a reply. It specifies the duration for which the - message and its potential reply are considered valid. + this message to receiving a reply. It specifies in seconds the duration for + which the message and its potential reply are considered valid. Returns ------- diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py index 5136f4f9021..3f1cd552250 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -19,7 +19,7 @@ import unittest from unittest.mock import Mock, patch -from flwr.common import RecordSet +from flwr.common import DEFAULT_TTL, RecordSet from flwr.common.message import Error from flwr.common.serde import error_to_proto, recordset_to_proto from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 @@ -99,7 +99,8 @@ def test_push_messages_valid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Execute @@ -121,7 +122,8 @@ def test_push_messages_invalid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Use invalid run_id msgs[1].metadata._run_id += 1 # pylint: disable=protected-access @@ -170,7 
+172,7 @@ def test_send_and_receive_messages_complete(self) -> None: task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) @@ -187,7 +189,7 @@ def test_send_and_receive_messages_timeout(self) -> None: self.mock_grpc_driver.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 59e51ef52d8..c5e8d055b70 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -15,6 +15,7 @@ """Driver API servicer.""" +import time from logging import DEBUG, INFO from typing import List, Optional, Set from uuid import UUID @@ -72,6 +73,11 @@ def PushTaskIns( """Push a set of TaskIns.""" log(DEBUG, "DriverServicer.PushTaskIns") + # Set pushed_at (timestamp in seconds) + pushed_at = time.time() + for task_ins in request.task_ins_list: + task_ins.task.pushed_at = pushed_at + # Validate request _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty") for task_ins in request.task_ins_list: diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 27847447737..eb8dd800ea3 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ 
b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -15,7 +15,7 @@ """Fleet API gRPC request-response servicer.""" -from logging import INFO +from logging import DEBUG, INFO import grpc @@ -26,6 +26,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + PingRequest, + PingResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -61,6 +63,14 @@ def DeleteNode( state=self.state_factory.state(), ) + def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse: + """.""" + log(DEBUG, "FleetServicer.Ping") + return message_handler.ping( + request=request, + state=self.state_factory.state(), + ) + def PullTaskIns( self, request: PullTaskInsRequest, context: grpc.ServicerContext ) -> PullTaskInsResponse: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index c99a7854d53..d4e63a8f2d4 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -15,6 +15,7 @@ """Fleet API message handlers.""" +import time from typing import List, Optional from uuid import UUID @@ -23,6 +24,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + PingRequest, + PingResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -55,6 +58,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: return DeleteNodeResponse() +def ping( + request: PingRequest, # pylint: disable=unused-argument + state: State, # pylint: disable=unused-argument +) -> PingResponse: + """.""" + return PingResponse(success=True) + + def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse: """Pull TaskIns handler.""" # Get node_id if client node is not anonymous @@ -77,6 +88,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo task_res: 
TaskRes = request.task_res_list[0] # pylint: enable=no-member + # Set pushed_at (timestamp in seconds) + task_res.task.pushed_at = time.time() + # Store TaskRes in State task_id: Optional[UUID] = state.store_task_res(task_res=task_res) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 8ef0d54622a..9bede09edf0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -20,7 +20,7 @@ import ray -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message @@ -151,7 +151,6 @@ async def process_message( ) await future - # Fetch result ( out_mssg, @@ -160,13 +159,15 @@ async def process_message( return out_mssg, updated_context - except LoadClientAppError as load_ex: + except Exception as ex: log( ERROR, "An exception was raised when processing a message by %s", self.__class__.__name__, ) - raise load_ex + # add actor back into pool + await self.pool.add_actor_back_to_pool(future) + raise ex async def terminate(self) -> None: """Terminate all actors in actor pool.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 2610307bb74..dcac0b81d66 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -25,6 +25,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -111,7 +112,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", 
+ ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index a693c968d0e..9736ae0fb57 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,9 +14,10 @@ # ============================================================================== """Fleet Simulation Engine API.""" - import asyncio import json +import sys +import time import traceback from logging import DEBUG, ERROR, INFO, WARN from typing import Callable, Dict, List, Optional @@ -24,6 +25,7 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.client.node_state import NodeState from flwr.common.logger import log +from flwr.common.message import Error from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 @@ -59,6 +61,7 @@ async def worker( """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" state = state_factory.state() while True: + out_mssg = None try: task_ins: TaskIns = await queue.get() node_id = task_ins.task.consumer.node_id @@ -82,24 +85,25 @@ async def worker( task_ins.run_id, context=updated_context ) - # Convert to TaskRes - task_res = message_to_taskres(out_mssg) - # Store TaskRes in state - state.store_task_res(task_res) - except asyncio.CancelledError as e: - log(DEBUG, "Async worker: %s", e) + log(DEBUG, "Terminating async worker: %s", e) break - except LoadClientAppError as app_ex: - log(ERROR, "Async worker: %s", app_ex) - log(ERROR, traceback.format_exc()) - raise - + # Exceptions aren't raised but reported as an error message except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) log(ERROR, traceback.format_exc()) - break + reason = str(type(ex)) + ":<'" + str(ex) + "'>" + error = Error(code=0, 
reason=reason) + out_mssg = message.create_error_reply(error=error) + + finally: + if out_mssg: + # Convert to TaskRes + task_res = message_to_taskres(out_mssg) + # Store TaskRes in state + task_res.task.pushed_at = time.time() + state.store_task_res(task_res) async def add_taskins_to_queue( @@ -218,7 +222,7 @@ async def run( await backend.terminate() -# pylint: disable=too-many-arguments,unused-argument,too-many-locals +# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches def start_vce( backend_name: str, backend_config_json_stream: str, @@ -300,12 +304,14 @@ def backend_fn() -> Backend: """Instantiate a Backend.""" return backend_type(backend_config, work_dir=app_dir) - log(INFO, "client_app_attr = %s", client_app_attr) - # Load ClientApp if needed def _load() -> ClientApp: if client_app_attr: + + if app_dir is not None: + sys.path.insert(0, app_dir) + app: ClientApp = load_app(client_app_attr, LoadClientAppError) if not isinstance(app, ClientApp): @@ -319,13 +325,23 @@ def _load() -> ClientApp: app_fn = _load - asyncio.run( - run( - app_fn, - backend_fn, - nodes_mapping, - state_factory, - node_states, - f_stop, + try: + # Test if ClientApp can be loaded + _ = app_fn() + + # Run main simulation loop + asyncio.run( + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, + ) ) - ) + except LoadClientAppError as loadapp_ex: + f_stop.set() # set termination event + raise loadapp_ex + except Exception as ex: + raise ex diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 8c37399ae29..66c3c21326d 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -17,6 +17,7 @@ import asyncio import threading +import time from itertools import cycle from json import JSONDecodeError from math import pi @@ -26,7 +27,14 @@ from unittest import 
IsolatedAsyncioTestCase from uuid import UUID -from flwr.common import GetPropertiesIns, Message, MessageTypeLegacy, Metadata +from flwr.client.client_app import LoadClientAppError +from flwr.common import ( + DEFAULT_TTL, + GetPropertiesIns, + Message, + MessageTypeLegacy, + Metadata, +) from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.server.superlink.fleet.vce.vce_api import ( @@ -46,7 +54,6 @@ def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None: def init_state_factory_nodes_mapping( num_nodes: int, num_messages: int, - erroneous_message: Optional[bool] = False, ) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]: """Instatiate StateFactory, register nodes and pre-insert messages in the state.""" # Register a state and a run_id in it @@ -61,7 +68,6 @@ def init_state_factory_nodes_mapping( nodes_mapping=nodes_mapping, run_id=run_id, num_messages=num_messages, - erroneous_message=erroneous_message, ) return state_factory, nodes_mapping, expected_results @@ -72,7 +78,6 @@ def register_messages_into_state( nodes_mapping: NodeToPartitionMapping, run_id: int, num_messages: int, - erroneous_message: Optional[bool] = False, ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore @@ -97,16 +102,15 @@ def register_messages_into_state( src_node_id=0, dst_node_id=dst_node_id, # indicate destination node reply_to_message="", - ttl="", - message_type=( - "a bad message" - if erroneous_message - else MessageTypeLegacy.GET_PROPERTIES - ), + ttl=DEFAULT_TTL, + message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) # Convert Message to TaskIns taskins = message_to_taskins(message) + # Normally recorded by the driver servicer + # but since we don't have one in this test, we do this manually + taskins.task.pushed_at = time.time() # Instert in state task_id = 
state.store_task_ins(taskins) if task_id: @@ -190,32 +194,13 @@ def test_erroneous_client_app_attr(self) -> None: state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping( num_nodes=num_nodes, num_messages=num_messages ) - with self.assertRaises(RuntimeError): + with self.assertRaises(LoadClientAppError): start_and_shutdown( client_app_attr="totally_fictitious_app:client", state_factory=state_factory, nodes_mapping=nodes_mapping, ) - def test_erroneous_messages(self) -> None: - """Test handling of error in async worker (consumer). - - We register messages which will trigger an error when handling, triggering an - error. - """ - num_messages = 100 - num_nodes = 59 - - state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping( - num_nodes=num_nodes, num_messages=num_messages, erroneous_message=True - ) - - with self.assertRaises(RuntimeError): - start_and_shutdown( - state_factory=state_factory, - nodes_mapping=nodes_mapping, - ) - def test_erroneous_backend_config(self) -> None: """Backend Config should be a JSON stream.""" with self.assertRaises(JSONDecodeError): diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ac1ab158e25..6fc57707ac3 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,9 +17,9 @@ import os import threading -from datetime import datetime, timedelta +import time from logging import ERROR -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from uuid import UUID, uuid4 from flwr.common import log, now @@ -32,7 +32,8 @@ class InMemoryState(State): """In-memory State implementation.""" def __init__(self) -> None: - self.node_ids: Set[int] = set() + # Map node_id to (online_until, ping_interval) + self.node_ids: Dict[int, Tuple[float, float]] = {} self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} 
self.task_res_store: Dict[UUID, TaskRes] = {} @@ -50,15 +51,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id task_id = uuid4() - created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) - task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() with self.lock: self.task_ins_store[task_id] = task_ins @@ -113,15 +110,11 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id task_id = uuid4() - created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskRes task_res.task_id = str(task_id) - task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = ttl.isoformat() with self.lock: self.task_res_store[task_id] = task_res @@ -194,17 +187,21 @@ def create_node(self) -> int: # Sample a random int64 as node_id node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - if node_id not in self.node_ids: - self.node_ids.add(node_id) - return node_id + with self.lock: + if node_id not in self.node_ids: + # Default ping interval is 30s + # TODO: change 1e9 to 30s # pylint: disable=W0511 + self.node_ids[node_id] = (time.time() + 1e9, 1e9) + return node_id log(ERROR, "Unexpected node registration failure.") return 0 def delete_node(self, node_id: int) -> None: """Delete a client node.""" - if node_id not in self.node_ids: - raise ValueError(f"Node {node_id} not found") - self.node_ids.remove(node_id) + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + del self.node_ids[node_id] def get_nodes(self, run_id: int) -> Set[int]: """Return all available client nodes. 
@@ -214,17 +211,32 @@ def get_nodes(self, run_id: int) -> Set[int]: If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ - if run_id not in self.run_ids: - return set() - return self.node_ids + with self.lock: + if run_id not in self.run_ids: + return set() + current_time = time.time() + return { + node_id + for node_id, (online_until, _) in self.node_ids.items() + if online_until > current_time + } def create_run(self) -> int: """Create one run.""" # Sample a random int64 as run_id - run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + with self.lock: + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - if run_id not in self.run_ids: - self.run_ids.add(run_id) - return run_id + if run_id not in self.run_ids: + self.run_ids.add(run_id) + return run_id log(ERROR, "Unexpected run creation failure.") return 0 + + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: + """Acknowledge a ping received from a node, serving as a heartbeat.""" + with self.lock: + if node_id in self.node_ids: + self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) + return True + return False diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 224c16cdf01..6996d51d2a9 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,7 +18,7 @@ import os import re import sqlite3 -from datetime import datetime, timedelta +import time from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -33,10 +33,16 @@ SQL_CREATE_TABLE_NODE = """ CREATE TABLE IF NOT EXISTS node( - node_id INTEGER UNIQUE + node_id INTEGER UNIQUE, + online_until REAL, + ping_interval REAL ); """ +SQL_CREATE_INDEX_ONLINE_UNTIL = """ +CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until); 
+""" + SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE @@ -52,9 +58,10 @@ producer_node_id INTEGER, consumer_anonymous BOOLEAN, consumer_node_id INTEGER, - created_at TEXT, + created_at REAL, delivered_at TEXT, - ttl TEXT, + pushed_at REAL, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -72,9 +79,10 @@ producer_node_id INTEGER, consumer_anonymous BOOLEAN, consumer_node_id INTEGER, - created_at TEXT, + created_at REAL, delivered_at TEXT, - ttl TEXT, + pushed_at REAL, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -82,7 +90,7 @@ ); """ -DictOrTuple = Union[Tuple[Any], Dict[str, Any]] +DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]] class SqliteState(State): @@ -123,6 +131,7 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) + cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL) res = cur.execute("SELECT name FROM sqlite_schema;") return res.fetchall() @@ -185,15 +194,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id task_id = uuid4() - created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) - task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -320,15 +325,11 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id task_id = uuid4() - created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_res.task_id = str(task_id) - task_res.task.created_at = created_at.isoformat() - task_res.task.ttl 
= ttl.isoformat() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" @@ -472,9 +473,14 @@ def create_node(self) -> int: # Sample a random int64 as node_id node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - query = "INSERT INTO node VALUES(:node_id);" + query = ( + "INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)" + ) + try: - self.query(query, {"node_id": node_id}) + # Default ping interval is 30s + # TODO: change 1e9 to 30s # pylint: disable=W0511 + self.query(query, (node_id, time.time() + 1e9, 1e9)) except sqlite3.IntegrityError: log(ERROR, "Unexpected node registration failure.") return 0 @@ -499,8 +505,8 @@ def get_nodes(self, run_id: int) -> Set[int]: return set() # Get nodes - query = "SELECT * FROM node;" - rows = self.query(query) + query = "SELECT node_id FROM node WHERE online_until > ?;" + rows = self.query(query, (time.time(),)) result: Set[int] = {row["node_id"] for row in rows} return result @@ -519,6 +525,17 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: + """Acknowledge a ping received from a node, serving as a heartbeat.""" + # Update `online_until` and `ping_interval` for the given `node_id` + query = "UPDATE node SET online_until = ?, ping_interval = ? 
WHERE node_id = ?;" + try: + self.query(query, (time.time() + ping_interval, ping_interval, node_id)) + return True + except sqlite3.IntegrityError: + log(ERROR, "`node_id` does not exist.") + return False + def dict_factory( cursor: sqlite3.Cursor, @@ -544,6 +561,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: "consumer_node_id": task_msg.task.consumer.node_id, "created_at": task_msg.task.created_at, "delivered_at": task_msg.task.delivered_at, + "pushed_at": task_msg.task.pushed_at, "ttl": task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), "task_type": task_msg.task.task_type, @@ -564,6 +582,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: "consumer_node_id": task_msg.task.consumer.node_id, "created_at": task_msg.task.created_at, "delivered_at": task_msg.task.delivered_at, + "pushed_at": task_msg.task.pushed_at, "ttl": task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), "task_type": task_msg.task.task_type, @@ -592,6 +611,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: ), created_at=task_dict["created_at"], delivered_at=task_dict["delivered_at"], + pushed_at=task_dict["pushed_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), task_type=task_dict["task_type"], @@ -621,6 +641,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: ), created_at=task_dict["created_at"], delivered_at=task_dict["delivered_at"], + pushed_at=task_dict["pushed_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), task_type=task_dict["task_type"], diff --git a/src/py/flwr/server/superlink/state/sqlite_state_test.py b/src/py/flwr/server/superlink/state/sqlite_state_test.py index 9eef71e396e..20927df1cf1 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state_test.py +++ b/src/py/flwr/server/superlink/state/sqlite_state_test.py @@ -38,6 +38,7 @@ def test_ins_res_to_dict(self) -> None: "consumer_node_id", "created_at", "delivered_at", + "pushed_at", "ttl", "ancestry", 
"task_type", diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 9337ae6d862..313290eb102 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -152,3 +152,22 @@ def get_nodes(self, run_id: int) -> Set[int]: @abc.abstractmethod def create_run(self) -> int: """Create one run.""" + + @abc.abstractmethod + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: + """Acknowledge a ping received from a node, serving as a heartbeat. + + Parameters + ---------- + node_id : int + The `node_id` from which the ping was received. + ping_interval : float + The interval (in seconds) from the current timestamp within which the next + ping from this node must be received. This acts as a hard deadline to ensure + an accurate assessment of the node's availability. + + Returns + ------- + is_acknowledged : bool + True if the ping is successfully acknowledged; otherwise, False. + """ diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index d0470a7ce7f..1757cfac425 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -16,12 +16,15 @@ # pylint: disable=invalid-name, disable=R0904 import tempfile +import time import unittest from abc import abstractmethod from datetime import datetime, timezone from typing import List +from unittest.mock import patch from uuid import uuid4 +from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -71,9 +74,8 @@ def test_store_task_ins_one(self) -> None: consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) - assert task_ins.task.created_at == "" # pylint: disable=no-member + assert 
task_ins.task.created_at < time.time() # pylint: disable=no-member assert task_ins.task.delivered_at == "" # pylint: disable=no-member - assert task_ins.task.ttl == "" # pylint: disable=no-member # Execute state.store_task_ins(task_ins=task_ins) @@ -89,19 +91,13 @@ def test_store_task_ins_one(self) -> None: actual_task = actual_task_ins.task - assert actual_task.created_at != "" assert actual_task.delivered_at != "" - assert actual_task.ttl != "" - assert datetime.fromisoformat(actual_task.created_at) > datetime( - 2020, 1, 1, tzinfo=timezone.utc - ) + assert actual_task.created_at < actual_task.pushed_at assert datetime.fromisoformat(actual_task.delivered_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc ) - assert datetime.fromisoformat(actual_task.ttl) > datetime( - 2020, 1, 1, tzinfo=timezone.utc - ) + assert actual_task.ttl > 0 def test_store_and_delete_tasks(self) -> None: """Test delete_tasks.""" @@ -398,6 +394,25 @@ def test_num_task_res(self) -> None: # Assert assert num == 2 + def test_acknowledge_ping(self) -> None: + """Test if acknowledge_ping works and if get_nodes return online nodes.""" + # Prepare + state: State = self.state_factory() + run_id = state.create_run() + node_ids = [state.create_node() for _ in range(100)] + for node_id in node_ids[:70]: + state.acknowledge_ping(node_id, ping_interval=30) + for node_id in node_ids[70:]: + state.acknowledge_ping(node_id, ping_interval=90) + + # Execute + current_time = time.time() + with patch("time.time", side_effect=lambda: current_time + 50): + actual_node_ids = state.get_nodes(run_id) + + # Assert + self.assertSetEqual(actual_node_ids, set(node_ids[70:])) + def create_task_ins( consumer_node_id: int, @@ -420,8 +435,11 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, + created_at=time.time(), ), ) + task.task.pushed_at = time.time() return task @@ -442,8 +460,11 @@ def create_task_res( ancestry=ancestry, 
task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, + created_at=time.time(), ), ) + task_res.task.pushed_at = time.time() return task_res @@ -477,7 +498,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 9 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -502,7 +523,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 9 if __name__ == "__main__": diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index f9b271beafd..c0b0ec85761 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -31,13 +31,21 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str if not tasks_ins_res.HasField("task"): validation_errors.append("`task` does not set field `task`") - # Created/delivered/TTL - if tasks_ins_res.task.created_at != "": - validation_errors.append("`created_at` must be an empty str") + # Created/delivered/TTL/Pushed + if ( + tasks_ins_res.task.created_at < 1711497600.0 + ): # unix timestamp of 27 March 2024 00h:00m:00s UTC + validation_errors.append( + "`created_at` must be a float that records the unix timestamp " + "in seconds when the message was created." 
+ ) if tasks_ins_res.task.delivered_at != "": validation_errors.append("`delivered_at` must be an empty str") - if tasks_ins_res.task.ttl != "": - validation_errors.append("`ttl` must be an empty str") + if tasks_ins_res.task.ttl <= 0: + validation_errors.append("`ttl` must be higher than zero") + if tasks_ins_res.task.pushed_at < 1711497600.0: + # unix timestamp of 27 March 2024 00h:00m:00s UTC + validation_errors.append("`pushed_at` is not a recent timestamp") # TaskIns specific if isinstance(tasks_ins_res, TaskIns): @@ -66,8 +74,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -106,8 +117,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 8e084950802..61fe094c23d 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -15,9 +15,11 @@ """Validator tests.""" +import time import unittest from typing import List, Tuple +from flwr.common import DEFAULT_TTL from 
flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -96,8 +98,12 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, + created_at=time.time(), ), ) + + task.task.pushed_at = time.time() return task @@ -117,6 +123,10 @@ def create_task_res( ancestry=ancestry, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, + created_at=time.time(), ), ) + + task_res.task.pushed_at = time.time() return task_res diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 876ae56dcad..42b1151f983 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -21,7 +21,7 @@ from typing import Optional, cast import flwr.common.recordset_compat as compat -from flwr.common import ConfigsRecord, Context, GetParametersIns, log +from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log from flwr.common.constant import MessageType, MessageTypeLegacy from ..compat.app_utils import start_update_client_manager_thread @@ -127,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, group_id="0", - ttl="", + ttl=DEFAULT_TTL, ) ] ) @@ -226,7 +226,7 @@ def default_fit_workflow( # pylint: disable=R0914 message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) for proxy, fitins in client_instructions ] @@ -306,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: message_type=MessageType.EVALUATE, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + 
ttl=DEFAULT_TTL, ) for proxy, evalins in client_instructions ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 42ee9c15f1c..326947b653f 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -22,6 +22,7 @@ import flwr.common.recordset_compat as compat from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, FitRes, @@ -373,7 +374,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -421,7 +422,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) # Broadcast public keys to clients and receive secret key shares @@ -492,7 +493,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -563,7 +564,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) log( diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 08d0576e39f..9773203628a 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -493,13 +493,17 @@ async def submit( self._future_to_actor[future] = actor return future + async def add_actor_back_to_pool(self, future: Any) -> None: + """Ad actor assigned to run future back into the pool.""" + actor = self._future_to_actor.pop(future) + await self.pool.put(actor) + async def fetch_result_and_return_actor_to_pool( self, future: Any ) -> Tuple[Message, Context]: """Pull result given a future 
and add actor back to pool.""" # Get actor that ran job - actor = self._future_to_actor.pop(future) - await self.pool.put(actor) + await self.add_actor_back_to_pool(future) # Retrieve result for object store # Instead of doing ray.get(future) we await it _, out_mssg, updated_context = await future diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index c3493163ac5..5e344eb087e 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -23,7 +23,7 @@ from flwr.client import ClientFn from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.logger import log from flwr.common.recordset_compat import ( @@ -105,7 +105,7 @@ def _wrap_recordset_in_message( src_node_id=0, dst_node_id=int(self.cid), reply_to_message="", - ttl=str(timeout) if timeout else "", + ttl=timeout if timeout else DEFAULT_TTL, message_type=message_type, partition_id=int(self.cid), ), diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 22c5425cd9f..9680b3846f1 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -24,6 +24,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -202,7 +203,7 @@ def _load_app() -> ClientApp: src_node_id=0, dst_node_id=12345, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, partition_id=int(cid), ),