Skip to content

Commit

Permalink
Handle ClientApp exception (#2846)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: jafermarq <[email protected]>
  • Loading branch information
danieljanes and jafermarq authored Mar 28, 2024
1 parent 3f282d4 commit 540adef
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 39 deletions.
33 changes: 21 additions & 12 deletions examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,19 @@ def main(driver: Driver, context: Context) -> None:
all_replies: List[Message] = []
while True:
replies = driver.pull_messages(message_ids=message_ids)
print(f"Got {len(replies)} results")
for res in replies:
print(f"Got 1 {'result' if res.has_content() else 'error'}")
all_replies += replies
if len(all_replies) == len(message_ids):
break
print("Pulling messages...")
time.sleep(3)

# Collect correct results
# Filter correct results
all_fitres = [
recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies
recordset_to_fitres(msg.content, keep_input=True)
for msg in all_replies
if msg.has_content()
]
print(f"Received {len(all_fitres)} results")

Expand All @@ -128,16 +132,21 @@ def main(driver: Driver, context: Context) -> None:
)
metrics_results.append((fitres.num_examples, fitres.metrics))

# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated
if len(weights_results) > 0:
# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated

# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
else:
print(
f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..."
)

# Slow down the start of the next round
time.sleep(sleep_time)
Expand Down
59 changes: 35 additions & 24 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# ==============================================================================
"""Flower client app."""


import argparse
import sys
import time
from logging import DEBUG, INFO, WARN
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -38,6 +37,7 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.message import Error
from flwr.common.object_ref import load_app, validate
from flwr.common.retry_invoker import RetryInvoker, exponential

Expand Down Expand Up @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp:
# Retrieve context for this run
context = node_state.retrieve_context(run_id=message.metadata.run_id)

# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()
# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
error=Error(code=0, reason="Unknown")
)

# Handle task message
out_message = client_app(message=message, context=context)
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()

# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
reply_message = client_app(message=message, context=context)
# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp raised an exception", exc_info=ex)

# Legacy grpc-bidi
if transport in ["grpc-bidi", None]:
# Raise exception, crash process
raise ex

# Don't update/change NodeState

# Create error message
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
reply_message = message.create_error_reply(
error=Error(code=0, reason=reason)
)

# Send
send(out_message)
log(
INFO,
"[RUN %s, ROUND %s]",
out_message.metadata.run_id,
out_message.metadata.group_id,
)
log(
INFO,
"Sent: %s reply to message %s",
out_message.metadata.message_type,
message.metadata.message_id,
)
send(reply_message)
log(INFO, "Sent reply")

# Unregister node
if delete_node is not None:
Expand Down
16 changes: 16 additions & 0 deletions src/py/flwr/server/compat/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,24 @@ def _send_receive_recordset(
)
if len(task_res_list) == 1:
task_res = task_res_list[0]

# This will raise an Exception if task_res carries an `error`
validate_task_res(task_res=task_res)

return serde.recordset_from_proto(task_res.task.recordset)

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
time.sleep(SLEEP_TIME)


def validate_task_res(
    task_res: task_pb2.TaskRes,  # pylint: disable=E1101
) -> None:
    """Ensure that a TaskRes carries a usable result.

    A TaskRes passes validation only if it has a `task` field, that task does
    not carry an `error`, and that task carries a `recordset`.

    Parameters
    ----------
    task_res : task_pb2.TaskRes
        The TaskRes to validate.

    Raises
    ------
    ValueError
        If the `task` field is missing, if the task carries an `error`, or if
        the task carries neither a `recordset` nor an `error`.
    """
    if not task_res.HasField("task"):
        raise ValueError("Invalid TaskRes, field `task` missing")

    task = task_res.task
    if task.HasField("error"):
        # Include the reported code and reason in the message; without them
        # the caller only learns that *something* failed on the client side.
        error = task.error
        raise ValueError(
            "Exception during client-side task execution "
            f"(code: {error.code}, reason: {error.reason})"
        )
    if not task.HasField("recordset"):
        raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
85 changes: 82 additions & 3 deletions src/py/flwr/server/compat/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@
Properties,
Status,
)
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611

from .driver_client_proxy import DriverClientProxy
from flwr.proto import ( # pylint: disable=E0611
driver_pb2,
error_pb2,
node_pb2,
recordset_pb2,
task_pb2,
)
from flwr.server.compat.driver_client_proxy import DriverClientProxy, validate_task_res

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")

Expand Down Expand Up @@ -243,3 +248,77 @@ def test_evaluate(self) -> None:
# Assert
assert 0.0 == evaluate_res.loss
assert 0 == evaluate_res.num_examples

def test_validate_task_res_valid(self) -> None:
    """Test that a TaskRes carrying a recordset passes validation."""
    # Prepare
    loss_value = recordset_pb2.MetricsRecordValue(  # pylint: disable=E1101
        double=1.0
    )
    metrics_record = recordset_pb2.MetricsRecord(  # pylint: disable=E1101
        data={"loss": loss_value}
    )
    recordset = recordset_pb2.RecordSet(  # pylint: disable=E1101
        parameters={},
        metrics={"loss": metrics_record},
        configs={},
    )
    task = task_pb2.Task(recordset=recordset)  # pylint: disable=E1101
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=task,
    )

    # Execute & assert
    # A ValueError here is reported as a test failure, not a test error
    try:
        validate_task_res(task_res=task_res)
    except ValueError:
        self.fail()

def test_validate_task_res_missing_task(self) -> None:
    """Test invalid TaskRes (missing task)."""
    # Prepare
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
    )

    # Execute & assert
    # Match on the message so the test pins the missing-`task` check: a
    # TaskRes without a task would also fail the later `recordset` check
    # with the same exception type.
    with self.assertRaisesRegex(ValueError, "field `task` missing"):
        validate_task_res(task_res=task_res)

def test_validate_task_res_missing_recordset(self) -> None:
    """Test invalid TaskRes (task carries neither recordset nor error)."""
    # Prepare
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=task_pb2.Task(),  # pylint: disable=E1101
    )

    # Execute & assert
    # Match on the message so the test pins the missing-`recordset` check
    # rather than accepting any ValueError.
    with self.assertRaisesRegex(
        ValueError, "both `recordset` and `error` are missing"
    ):
        validate_task_res(task_res=task_res)

def test_validate_task_res_missing_content(self) -> None:
    """Test invalid TaskRes (task carries an error instead of content)."""
    # Prepare
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=task_pb2.Task(  # pylint: disable=E1101
            error=error_pb2.Error(  # pylint: disable=E1101
                code=0,
                reason="Some reason",
            )
        ),
    )

    # Execute & assert
    # This task has no recordset either, so a bare `assertRaises(ValueError)`
    # would still pass if the error check were removed. Match on the message
    # to pin the error branch.
    with self.assertRaisesRegex(
        ValueError, "Exception during client-side task execution"
    ):
        validate_task_res(task_res=task_res)

0 comments on commit 540adef

Please sign in to comment.