Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal committed Mar 6, 2024
1 parent 44c0098 commit 81e1c7b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 43 deletions.
48 changes: 24 additions & 24 deletions src/py/flwr/cli/flower_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,38 @@ def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
return data


def validate_flower_toml_fields(config: Dict[str, Any]) -> Tuple[bool, List[str]]:
def validate_flower_toml_fields(
config: Dict[str, Any]
) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml fields."""
invalid_reasons = []
errors = []
warnings = []

if "project" not in config:
invalid_reasons.append("Missing [project] section")
errors.append("Missing [project] section")
else:
if "name" not in config["project"]:
invalid_reasons.append('Property "name" missing in [project]')
errors.append('Property "name" missing in [project]')
if "version" not in config["project"]:
invalid_reasons.append('Property "version" missing in [project]')
errors.append('Property "version" missing in [project]')
if "description" not in config["project"]:
invalid_reasons.append('Property "description" missing in [project]')
warnings.append('Recommended property "description" missing in [project]')
if "license" not in config["project"]:
invalid_reasons.append('Property "license" missing in [project]')
warnings.append('Recommended property "license" missing in [project]')
if "authors" not in config["project"]:
invalid_reasons.append('Property "authors" missing in [project]')
warnings.append('Recommended property "authors" missing in [project]')

if "flower" not in config:
invalid_reasons.append("Missing [flower] section")
errors.append("Missing [flower] section")
elif "components" not in config["flower"]:
invalid_reasons.append("Missing [flower.components] section")
errors.append("Missing [flower.components] section")
else:
if "serverapp" not in config["flower"]["components"]:
invalid_reasons.append(
'Property "serverapp" missing in [flower.components]'
)
errors.append('Property "serverapp" missing in [flower.components]')
if "clientapp" not in config["flower"]["components"]:
invalid_reasons.append(
'Property "clientapp" missing in [flower.components]'
)
errors.append('Property "clientapp" missing in [flower.components]')

return len(invalid_reasons) == 0, invalid_reasons
return len(errors) == 0, errors, warnings


def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
Expand All @@ -78,7 +77,8 @@ def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
Returns
-------
Tuple[bool, Optional[str]]
A boolean indicating whether an object reference is valid and the reason why it might not be.
A boolean indicating whether an object reference is valid and
the reason why it might not be.
"""
module_str, _, attributes_str = ref.partition(":")
if not module_str:
Expand Down Expand Up @@ -112,29 +112,29 @@ def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
return (True, None)


def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str]]:
def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml."""
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

if not is_valid:
return False, reasons
return False, errors, warnings

# Validate serverapp
is_valid, reason = validate_object_reference(
config["flower"]["components"]["serverapp"]
)
if not is_valid and isinstance(reason, str):
return False, [reason]
return False, [reason], []

# Validate clientapp
is_valid, reason = validate_object_reference(
config["flower"]["components"]["clientapp"]
)

if not is_valid and isinstance(reason, str):
return False, [reason]
return False, [reason], []

return True, []
return True, [], []


def apply_defaults(
Expand Down
41 changes: 24 additions & 17 deletions src/py/flwr/cli/flower_toml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None:
},
"engine": {
"name": "simulation",
"simulation": {"super-node": {"count": 10}},
"simulation": {"supernode": {"count": 10}},
},
},
}
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_load_flower_toml_from_path(tmp_path: str) -> None:
[flower.engine]
name = "simulation" # optional
[flower.engine.simulation.super-node]
[flower.engine.simulation.supernode]
count = 10 # optional
"""
expected_config = {
Expand All @@ -105,7 +105,7 @@ def test_load_flower_toml_from_path(tmp_path: str) -> None:
},
"engine": {
"name": "simulation",
"simulation": {"super-node": {"count": 10}},
"simulation": {"supernode": {"count": 10}},
},
},
}
Expand Down Expand Up @@ -134,11 +134,12 @@ def test_validate_flower_toml_fields_empty() -> None:
config: Dict[str, Any] = {}

# Execute
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

# Assert
assert not is_valid
assert len(reasons) == 2
assert len(errors) == 2
assert len(warnings) == 0


def test_validate_flower_toml_fields_no_flower() -> None:
Expand All @@ -155,11 +156,12 @@ def test_validate_flower_toml_fields_no_flower() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

# Assert
assert not is_valid
assert len(reasons) == 1
assert len(errors) == 1
assert len(warnings) == 0


def test_validate_flower_toml_fields_no_flower_components() -> None:
Expand All @@ -177,11 +179,12 @@ def test_validate_flower_toml_fields_no_flower_components() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

# Assert
assert not is_valid
assert len(reasons) == 1
assert len(errors) == 1
assert len(warnings) == 0


def test_validate_flower_toml_fields_no_server_and_client_app() -> None:
Expand All @@ -199,11 +202,12 @@ def test_validate_flower_toml_fields_no_server_and_client_app() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

# Assert
assert not is_valid
assert len(reasons) == 2
assert len(errors) == 2
assert len(warnings) == 0


def test_validate_flower_toml_fields() -> None:
Expand All @@ -221,11 +225,12 @@ def test_validate_flower_toml_fields() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_flower_toml_fields(config)

# Assert
assert is_valid
assert len(reasons) == 0
assert len(errors) == 0
assert len(warnings) == 0


def test_validate_object_reference() -> None:
Expand Down Expand Up @@ -274,11 +279,12 @@ def test_validate_flower_toml() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml(config)
is_valid, errors, warnings = validate_flower_toml(config)

# Assert
assert is_valid
assert not reasons
assert not errors
assert not warnings


def test_validate_flower_toml_fail() -> None:
Expand All @@ -301,8 +307,9 @@ def test_validate_flower_toml_fail() -> None:
}

# Execute
is_valid, reasons = validate_flower_toml(config)
is_valid, errors, warnings = validate_flower_toml(config)

# Assert
assert not is_valid
assert len(reasons) == 1
assert len(errors) == 1
assert len(warnings) == 0
15 changes: 13 additions & 2 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,23 @@ def run() -> None:
)
sys.exit()

is_valid, reasons = validate_flower_toml(config)
is_valid, errors, warnings = validate_flower_toml(config)
if warnings:
print(
typer.style(
"Project configuration is missing the following "
"recommended properties:\n"
+ "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
)

if not is_valid:
print(
typer.style(
"Project configuration could not be loaded.\nflower.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in reasons]),
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
Expand Down

0 comments on commit 81e1c7b

Please sign in to comment.