test: clean up and test gate dataclass

Signed-off-by: Will Murphy <[email protected]>
willmurphyscode committed Sep 18, 2024
1 parent 9d8194e commit 11c2fc0
Showing 2 changed files with 130 additions and 24 deletions.
57 changes: 33 additions & 24 deletions src/yardstick/validate.py
@@ -106,10 +106,8 @@ def outcome(self) -> DeltaType:
 
 @dataclass
 class Gate:
-    # label_comparisons: InitVar[Optional[list[comparison.AgainstLabels]]]
-    reference_comparison: InitVar[Optional[comparison.AgainstLabels]]
-    candidate_comparison: InitVar[Optional[comparison.AgainstLabels]]
-    # label_comparison_stats: InitVar[Optional[comparison.ImageToolLabelStats]]
+    reference_comparison: InitVar[Optional[comparison.LabelComparisonSummary]]
+    candidate_comparison: InitVar[Optional[comparison.LabelComparisonSummary]]
 
     config: GateConfig
 
@@ -119,33 +117,34 @@ class Gate:
 
     def __post_init__(
         self,
-        # label_comparisons: Optional[list[comparison.AgainstLabels]],
-        reference_comparison: Optional[comparison.AgainstLabels],
-        candidate_comparison: Optional[comparison.AgainstLabels],
+        reference_comparison: Optional[comparison.LabelComparisonSummary],
+        candidate_comparison: Optional[comparison.LabelComparisonSummary],
     ):
         if not reference_comparison or not candidate_comparison:
             return
 
         reasons = []
 
-        reference_f1_score = reference_comparison.summary.f1_score
-        current_f1_score = candidate_comparison.summary.f1_score
+        reference_f1_score = reference_comparison.f1_score
+        current_f1_score = candidate_comparison.f1_score
         if current_f1_score < reference_f1_score - self.config.max_f1_regression:
             reasons.append(
-                f"current F1 score is lower than the latest release F1 score: {bcolors.BOLD + bcolors.UNDERLINE}candidate_score={current_f1_score:0.2f} reference_score={reference_f1_score:0.2f}{bcolors.RESET} image={image}"
+                f"current F1 score is lower than the latest release F1 score: {bcolors.BOLD + bcolors.UNDERLINE}candidate_score={current_f1_score:0.2f} reference_score={reference_f1_score:0.2f}{bcolors.RESET} image={self.input_description.image}"
             )
 
-        if candidate_comparison.summary.indeterminate_percent > self.config.max_unlabeled_percent:
+        if (
+            candidate_comparison.indeterminate_percent
+            > self.config.max_unlabeled_percent
+        ):
             reasons.append(
-                f"current indeterminate matches % is greater than {self.config.max_unlabeled_percent}%: {bcolors.BOLD + bcolors.UNDERLINE}candidate={comp.summary.indeterminate_percent:0.2f}%{bcolors.RESET} image={image}"
+                f"current indeterminate matches % is greater than {self.config.max_unlabeled_percent}%: {bcolors.BOLD + bcolors.UNDERLINE}candidate={candidate_comparison.indeterminate_percent:0.2f}%{bcolors.RESET} image={self.input_description.image}"
             )
 
-
-        reference_fns = reference_comparison.summary.false_negatives
-        candidate_fns = candidate_comparison.summary.false_negatives
+        reference_fns = reference_comparison.false_negatives
+        candidate_fns = candidate_comparison.false_negatives
         if candidate_fns > reference_fns + self.config.max_new_false_negatives:
             reasons.append(
-                f"current false negatives is greater than the latest release false negatives: {bcolors.BOLD + bcolors.UNDERLINE}candidate={candidate_fns} reference={reference_fns}{bcolors.RESET} image={image}"
+                f"current false negatives is greater than the latest release false negatives: {bcolors.BOLD + bcolors.UNDERLINE}candidate={candidate_fns} reference={reference_fns}{bcolors.RESET} image={self.input_description.image}"
             )
 
         self.reasons = reasons
@@ -158,7 +157,8 @@ def failing(cls, reasons: list[str], input_description: GateInputDescription):
         """failing bypasses Gate's normal validation calculation and returns a
         gate that is failing for the reasons given."""
         return cls(
-            reference_comparison=None, candidate_comparison=None,
+            reference_comparison=None,
+            candidate_comparison=None,
             config=GateConfig(),
             reasons=reasons,
             input_description=input_description,
@@ -168,7 +168,8 @@ def failing(cls, reasons: list[str], input_description: GateInputDescription):
     def passing(cls, input_description: GateInputDescription):
         """passing bypasses a Gate's normal validation and returns a gate that is passing."""
         return cls(
-            reference_comparison=None, candidate_comparison=None,
+            reference_comparison=None,
+            candidate_comparison=None,
             config=GateConfig(),
             reasons=[],  # a gate with no reason to fail is considered passing
             input_description=input_description,
@@ -413,14 +414,20 @@ def validate_image(
     )
 
     if len(results) != 2:
-        raise RuntimeError(f"validate_image compares results of exactly 2 runs, but found{len(results)}")
+        raise RuntimeError(
+            f"validate_image compares results of exactly 2 runs, but found {len(results)}"
+        )
 
-    candidate_tool, reference_tool = tool_designations(gate_config.candidate_tool_label, [r.config for r in results])
+    candidate_tool, reference_tool = tool_designations(
+        gate_config.candidate_tool_label, [r.config for r in results]
+    )
 
     # keep a list of differences between tools to summarize in the UI
     # note that this is different from the statistical comparison;
     # deltas are basically a UI/logging concern; the stats are a pass/fail concern.
-    deltas = compute_deltas(comparisons_by_result_id, reference_tool, relative_comparison)
+    deltas = compute_deltas(
+        comparisons_by_result_id, reference_tool, relative_comparison
+    )
 
     reference_comparisons_by_images = {
         comp.config.image: comp
@@ -435,15 +442,17 @@
     }
     candidate_comparison = candidate_comparisons_by_images[image]
     return Gate(
-        reference_comparison=reference_comparison,
-        candidate_comparison=candidate_comparison,
+        reference_comparison=reference_comparison.summary,
+        candidate_comparison=candidate_comparison.summary,
         config=gate_config,
         input_description=results_used(image, relative_comparison.results),
         deltas=deltas,
     )
 
 
-def tool_designations(candidate_tool_label: str, scan_configs: list[artifact.ScanConfiguration]) -> tuple[str, str]:
+def tool_designations(
+    candidate_tool_label: str, scan_configs: list[artifact.ScanConfiguration]
+) -> tuple[str, str]:
     reference_tool, candidate_tool = None, None
     if not candidate_tool_label:
         reference_tool, candidate_tool = guess_tool_orientation(
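For context, here is a minimal sketch (not part of this commit) of driving the reworked Gate directly. It assumes only what is visible in this diff and in the tests below — the GateConfig thresholds, the summary attributes (f1_score, false_negatives, indeterminate_percent), and Gate.passed() — and uses MagicMock stand-ins for comparison.LabelComparisonSummary:

from unittest.mock import MagicMock

from yardstick.validate import Gate, GateConfig, GateInputDescription

# Hypothetical summary values; MagicMock stands in for comparison.LabelComparisonSummary.
reference = MagicMock(f1_score=0.90, false_negatives=5, indeterminate_percent=2.0)
candidate = MagicMock(f1_score=0.70, false_negatives=5, indeterminate_percent=2.0)

gate = Gate(
    reference_comparison=reference,
    candidate_comparison=candidate,
    config=GateConfig(
        max_f1_regression=0.1,
        max_new_false_negatives=5,
        max_unlabeled_percent=10,
    ),
    input_description=GateInputDescription(image="some-image", configs=[]),
)

# The candidate F1 score (0.70) is more than 0.1 below the reference (0.90),
# so __post_init__ records a failure reason and the gate fails.
assert not gate.passed()
print(gate.reasons)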
97 changes: 97 additions & 0 deletions tests/unit/test_validate.py
@@ -0,0 +1,97 @@
import pytest
from unittest.mock import MagicMock

from yardstick.validate import Gate, GateConfig, GateInputDescription


@pytest.fixture
def mock_label_comparison():
"""Fixture to create a mock LabelComparisonSummary with defaults."""
summary = MagicMock()
summary.f1_score = 0.9
summary.false_negatives = 5
summary.indeterminate_percent = 2.0
return summary


@pytest.mark.parametrize(
"config, reference_summary, candidate_summary, expected_reasons",
[
# Case 1: Candidate has a lower F1 score beyond the allowed threshold -> gate fails
(
GateConfig(
max_f1_regression=0.1,
max_new_false_negatives=5,
max_unlabeled_percent=10,
),
MagicMock(f1_score=0.9, false_negatives=5, indeterminate_percent=2.0),
MagicMock(f1_score=0.7, false_negatives=5, indeterminate_percent=2.0),
["current F1 score is lower than the latest release F1 score"],
),
# Case 2: Candidate has too many false negatives -> gate fails
(
GateConfig(
max_f1_regression=0.1,
max_new_false_negatives=1,
max_unlabeled_percent=10,
),
MagicMock(f1_score=0.9, false_negatives=5, indeterminate_percent=2.0),
MagicMock(f1_score=0.85, false_negatives=7, indeterminate_percent=2.0),
[
"current false negatives is greater than the latest release false negatives"
],
),
# Case 3: Candidate has too high indeterminate percent -> gate fails
(
GateConfig(
max_f1_regression=0.1,
max_new_false_negatives=5,
max_unlabeled_percent=5,
),
MagicMock(f1_score=0.9, false_negatives=5, indeterminate_percent=2.0),
MagicMock(f1_score=0.85, false_negatives=5, indeterminate_percent=6.0),
["current indeterminate matches % is greater than"],
),
# Case 4: Candidate passes all thresholds -> gate passes (no reasons)
(
GateConfig(
max_f1_regression=0.1,
max_new_false_negatives=5,
max_unlabeled_percent=10,
),
MagicMock(f1_score=0.9, false_negatives=5, indeterminate_percent=2.0),
MagicMock(f1_score=0.85, false_negatives=5, indeterminate_percent=3.0),
[],
),
],
)
def test_gate(config, reference_summary, candidate_summary, expected_reasons):
"""Parameterized test for the Gate class that checks different pass/fail conditions."""

# Create the Gate instance with the given parameters
gate = Gate(
reference_comparison=reference_summary,
candidate_comparison=candidate_summary,
config=config,
input_description=MagicMock(image="test_image"),
)

# Check that the reasons list matches the expected outcome
assert len(gate.reasons) == len(expected_reasons)
for reason, expected_reason in zip(gate.reasons, expected_reasons):
assert expected_reason in reason


def test_gate_failing():
input_description = GateInputDescription(image="some-image", configs=[])
gate = Gate.failing(["sample failure reason"], input_description)
assert not gate.passed()
assert gate.reasons == ["sample failure reason"]


def test_gate_passing():
input_description = GateInputDescription(image="some-image", configs=[])
gate = Gate.passing(input_description)
assert gate.passed()
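
Assuming a standard pytest setup for the repository (not shown in this commit), the new tests can be run in isolation with:

pytest tests/unit/test_validate.py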
