Skip to content

Commit

Permalink
Handle ClientApp exception simulation (#3075)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Mar 28, 2024
1 parent 67ca7ab commit 531e0e3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 59 deletions.
9 changes: 5 additions & 4 deletions src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import ray

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.client_app import ClientApp
from flwr.common.context import Context
from flwr.common.logger import log
from flwr.common.message import Message
Expand Down Expand Up @@ -151,7 +151,6 @@ async def process_message(
)

await future

# Fetch result
(
out_mssg,
Expand All @@ -160,13 +159,15 @@ async def process_message(

return out_mssg, updated_context

except LoadClientAppError as load_ex:
except Exception as ex:
log(
ERROR,
"An exception was raised when processing a message by %s",
self.__class__.__name__,
)
raise load_ex
# add actor back into pool
await self.pool.add_actor_back_to_pool(future)
raise ex

async def terminate(self) -> None:
"""Terminate all actors in actor pool."""
Expand Down
66 changes: 41 additions & 25 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
# ==============================================================================
"""Fleet Simulation Engine API."""


import asyncio
import json
import sys
import time
import traceback
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, Dict, List, Optional

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common.logger import log
from flwr.common.message import Error
from flwr.common.object_ref import load_app
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
Expand Down Expand Up @@ -59,6 +61,7 @@ async def worker(
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
state = state_factory.state()
while True:
out_mssg = None
try:
task_ins: TaskIns = await queue.get()
node_id = task_ins.task.consumer.node_id
Expand All @@ -82,24 +85,25 @@ async def worker(
task_ins.run_id, context=updated_context
)

# Convert to TaskRes
task_res = message_to_taskres(out_mssg)
# Store TaskRes in state
state.store_task_res(task_res)

except asyncio.CancelledError as e:
log(DEBUG, "Async worker: %s", e)
log(DEBUG, "Terminating async worker: %s", e)
break

except LoadClientAppError as app_ex:
log(ERROR, "Async worker: %s", app_ex)
log(ERROR, traceback.format_exc())
raise

# Exceptions aren't raised but reported as an error message
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, ex)
log(ERROR, traceback.format_exc())
break
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
error = Error(code=0, reason=reason)
out_mssg = message.create_error_reply(error=error)

finally:
if out_mssg:
# Convert to TaskRes
task_res = message_to_taskres(out_mssg)
# Store TaskRes in state
task_res.task.pushed_at = time.time()
state.store_task_res(task_res)


async def add_taskins_to_queue(
Expand Down Expand Up @@ -218,7 +222,7 @@ async def run(
await backend.terminate()


# pylint: disable=too-many-arguments,unused-argument,too-many-locals
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
def start_vce(
backend_name: str,
backend_config_json_stream: str,
Expand Down Expand Up @@ -300,12 +304,14 @@ def backend_fn() -> Backend:
"""Instantiate a Backend."""
return backend_type(backend_config, work_dir=app_dir)

log(INFO, "client_app_attr = %s", client_app_attr)

# Load ClientApp if needed
def _load() -> ClientApp:

if client_app_attr:

if app_dir is not None:
sys.path.insert(0, app_dir)

app: ClientApp = load_app(client_app_attr, LoadClientAppError)

if not isinstance(app, ClientApp):
Expand All @@ -319,13 +325,23 @@ def _load() -> ClientApp:

app_fn = _load

asyncio.run(
run(
app_fn,
backend_fn,
nodes_mapping,
state_factory,
node_states,
f_stop,
try:
# Test if ClientApp can be loaded
_ = app_fn()

# Run main simulation loop
asyncio.run(
run(
app_fn,
backend_fn,
nodes_mapping,
state_factory,
node_states,
f_stop,
)
)
)
except LoadClientAppError as loadapp_ex:
f_stop.set() # set termination event
raise loadapp_ex
except Exception as ex:
raise ex
31 changes: 3 additions & 28 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from unittest import IsolatedAsyncioTestCase
from uuid import UUID

from flwr.client.client_app import LoadClientAppError
from flwr.common import (
DEFAULT_TTL,
GetPropertiesIns,
Expand All @@ -53,7 +54,6 @@ def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None:
def init_state_factory_nodes_mapping(
num_nodes: int,
num_messages: int,
erroneous_message: Optional[bool] = False,
) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]:
"""Instantiate StateFactory, register nodes and pre-insert messages in the state."""
# Register a state and a run_id in it
Expand All @@ -68,7 +68,6 @@ def init_state_factory_nodes_mapping(
nodes_mapping=nodes_mapping,
run_id=run_id,
num_messages=num_messages,
erroneous_message=erroneous_message,
)
return state_factory, nodes_mapping, expected_results

Expand All @@ -79,7 +78,6 @@ def register_messages_into_state(
nodes_mapping: NodeToPartitionMapping,
run_id: int,
num_messages: int,
erroneous_message: Optional[bool] = False,
) -> Dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
Expand All @@ -105,11 +103,7 @@ def register_messages_into_state(
dst_node_id=dst_node_id, # indicate destination node
reply_to_message="",
ttl=DEFAULT_TTL,
message_type=(
"a bad message"
if erroneous_message
else MessageTypeLegacy.GET_PROPERTIES
),
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
)
# Convert Message to TaskIns
Expand Down Expand Up @@ -200,32 +194,13 @@ def test_erroneous_client_app_attr(self) -> None:
state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping(
num_nodes=num_nodes, num_messages=num_messages
)
with self.assertRaises(RuntimeError):
with self.assertRaises(LoadClientAppError):
start_and_shutdown(
client_app_attr="totally_fictitious_app:client",
state_factory=state_factory,
nodes_mapping=nodes_mapping,
)

def test_erroneous_messages(self) -> None:
    """Test handling of an error raised in the async worker (consumer).

    The messages are registered with `erroneous_message=True`, so handling
    them is expected to fail and surface as a `RuntimeError`.
    """
    factory, mapping, _ = init_state_factory_nodes_mapping(
        num_nodes=59, num_messages=100, erroneous_message=True
    )

    # The failure must propagate out of the simulation run
    with self.assertRaises(RuntimeError):
        start_and_shutdown(state_factory=factory, nodes_mapping=mapping)

def test_erroneous_backend_config(self) -> None:
"""Backend Config should be a JSON stream."""
with self.assertRaises(JSONDecodeError):
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,17 @@ async def submit(
self._future_to_actor[future] = actor
return future

async def add_actor_back_to_pool(self, future: Any) -> None:
    """Add the actor assigned to run `future` back into the pool.

    The entry for `future` is removed from `self._future_to_actor` and the
    actor it mapped to is put into `self.pool`. The entry is popped without
    a default, so a `future` that is not tracked propagates the lookup error.
    """
    await self.pool.put(self._future_to_actor.pop(future))

async def fetch_result_and_return_actor_to_pool(
self, future: Any
) -> Tuple[Message, Context]:
"""Pull result given a future and add actor back to pool."""
# Get actor that ran job
actor = self._future_to_actor.pop(future)
await self.pool.put(actor)
await self.add_actor_back_to_pool(future)
# Retrieve result from the object store
# Instead of doing ray.get(future) we await it
_, out_mssg, updated_context = await future
Expand Down

0 comments on commit 531e0e3

Please sign in to comment.