Skip to content

Commit

Permalink
feat(framework) Add FAB hash to load_client_app_fn (#4305)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Oct 10, 2024
1 parent 2b9dbcc commit b857425
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 34 deletions.
4 changes: 1 addition & 3 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ def validate_and_install(
install_dir: Path = (
(get_flwr_dir() if not flwr_dir else flwr_dir)
/ "apps"
/ publisher
/ project_name
/ version
/ f"{publisher}.{project_name}.{version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
)
if install_dir.exists():
if skip_prompt:
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def start_client_internal(
*,
server_address: str,
node_config: UserConfig,
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
load_client_app_fn: Optional[Callable[[str, str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
Expand Down Expand Up @@ -298,7 +298,7 @@ def single_client_factory(

client_fn = single_client_factory

def _load_client_app(_1: str, _2: str) -> ClientApp:
def _load_client_app(_1: str, _2: str, _3: str) -> ClientApp:
return ClientApp(client_fn=client_fn)

load_client_app_fn = _load_client_app
Expand Down Expand Up @@ -529,7 +529,7 @@ def _on_backoff(retry_state: RetryState) -> None:
else:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn(
fab_id, fab_version
fab_id, fab_version, run.fab_hash
)

# Execute ClientApp
Expand Down
7 changes: 5 additions & 2 deletions src/py/flwr/client/clientapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ def run_clientapp( # pylint: disable=R0914
)

try:
# Load ClientApp
client_app: ClientApp = load_client_app_fn(run.fab_id, run.fab_version)
if fab:
# Load ClientApp
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version, fab.hash_str
)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down
16 changes: 11 additions & 5 deletions src/py/flwr/client/clientapp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_load_client_app_fn(
app_path: Optional[str],
multi_app: bool,
flwr_dir: Optional[str] = None,
) -> Callable[[str, str], ClientApp]:
) -> Callable[[str, str, str], ClientApp]:
"""Get the load_client_app_fn function.
If `multi_app` is True, this function loads the specified ClientApp
Expand All @@ -55,13 +55,14 @@ def get_load_client_app_fn(
if not valid and error_msg:
raise LoadClientAppError(error_msg) from None

def _load(fab_id: str, fab_version: str) -> ClientApp:
def _load(fab_id: str, fab_version: str, fab_hash: str) -> ClientApp:
runtime_app_dir = Path(app_path if app_path else "").absolute()
# If multi-app feature is disabled
if not multi_app:
# Set app reference
client_app_ref = default_app_ref
# If multi-app feature is enabled but app directory is provided
# If multi-app feature is enabled but app directory is provided.
# `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
elif app_path is not None:
config = get_project_config(runtime_app_dir)
this_fab_version, this_fab_id = get_metadata_from_config(config)
Expand All @@ -81,11 +82,16 @@ def _load(fab_id: str, fab_version: str) -> ClientApp:
else:
try:
runtime_app_dir = get_project_dir(
fab_id, fab_version, get_flwr_dir(flwr_dir)
fab_id, fab_version, fab_hash, get_flwr_dir(flwr_dir)
)
config = get_project_config(runtime_app_dir)
except Exception as e:
raise LoadClientAppError("Failed to load ClientApp") from e
raise LoadClientAppError(
"Failed to load ClientApp."
"Possible reasons for error include mismatched "
"`fab_id`, `fab_version`, or `fab_hash` in "
f"{str(get_flwr_dir(flwr_dir).resolve())}."
) from e

# Set app reference
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
Expand Down
20 changes: 16 additions & 4 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
FAB_HASH_TRUNCATION,
FLWR_HOME,
)
from flwr.common.typing import Run, UserConfig, UserConfigValue


Expand All @@ -39,7 +44,10 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:


def get_project_dir(
fab_id: str, fab_version: str, flwr_dir: Optional[Union[str, Path]] = None
fab_id: str,
fab_version: str,
fab_hash: str,
flwr_dir: Optional[Union[str, Path]] = None,
) -> Path:
"""Return the project directory based on the given fab_id and fab_version."""
# Check the fab_id
Expand All @@ -50,7 +58,11 @@ def get_project_dir(
publisher, project_name = fab_id.split("/")
if flwr_dir is None:
flwr_dir = get_flwr_dir()
return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version
return (
Path(flwr_dir)
/ APP_DIR
/ f"{publisher}.{project_name}.{fab_version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
)


def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
Expand Down Expand Up @@ -127,7 +139,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
if not run.fab_id or not run.fab_version:
return {}

project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
project_dir = get_project_dir(run.fab_id, run.fab_version, run.fab_hash, flwr_dir)

# Return empty dict if project directory does not exist
if not project_dir.is_dir():
Expand Down
15 changes: 12 additions & 3 deletions src/py/flwr/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,22 @@ def test_get_flwr_dir_with_xdg_data_home() -> None:
def test_get_project_dir_invalid_fab_id() -> None:
"""Test get_project_dir with an invalid fab_id."""
with pytest.raises(ValueError):
get_project_dir("invalid_fab_id", "1.0.0")
get_project_dir(
"invalid_fab_id",
"1.0.0",
"03840e932bf61247c1231f0aec9e8ec5f041ed5516fb23638f24d25f3a007acd",
)


def test_get_project_dir_valid() -> None:
"""Test get_project_dir with an valid fab_id and version."""
app_path = get_project_dir("app_name/user", "1.0.0", flwr_dir=".")
assert app_path == Path("apps") / "app_name" / "user" / "1.0.0"
app_path = get_project_dir(
"app_name/user",
"1.0.0",
"03840e932bf61247c1231f0aec9e8ec5f041ed5516fb23638f24d25f3a007acd",
flwr_dir=".",
)
assert app_path == Path("apps") / "app_name.user.1.0.0.03840e93"


def test_get_project_config_file_not_found() -> None:
Expand Down
24 changes: 11 additions & 13 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,17 @@ def run_server_app() -> None:
)
flwr_dir = get_flwr_dir(args.flwr_dir)
run_ = driver.run
if run_.fab_hash:
fab_req = GetFabRequest(hash_str=run_.fab_hash)
# pylint: disable-next=W0212
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
if fab_res.fab.hash_str != run_.fab_hash:
raise ValueError("FAB hashes don't match.")

install_from_fab(fab_res.fab.content, flwr_dir, True)
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
else:
fab_id, fab_version = run_.fab_id, run_.fab_version

app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
if not run_.fab_hash:
raise ValueError("FAB hash not provided.")
fab_req = GetFabRequest(hash_str=run_.fab_hash)
# pylint: disable-next=W0212
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
if fab_res.fab.hash_str != run_.fab_hash:
raise ValueError("FAB hashes don't match.")
install_from_fab(fab_res.fab.content, flwr_dir, True)
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)

app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
config = get_project_config(app_path)
else:
# User provided `app_dir`, but not `--run-id`
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def _load() -> ClientApp:
app_path=app_dir,
flwr_dir=flwr_dir,
multi_app=False,
)(run.fab_id, run.fab_version)
)(run.fab_id, run.fab_version, run.fab_hash)

if client_app:
app = client_app
Expand Down

0 comments on commit b857425

Please sign in to comment.