-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Metric refactor #69
Metric refactor #69
Changes from 9 commits
e912d4d
05e68b7
bb0b0ba
fcae1a3
a0d1a9a
0280e17
b8c0fd7
6f66a10
767d641
5231ec6
a20508b
8ecb8ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
import json | ||
from logging import INFO | ||
from typing import Dict, List | ||
from typing import Dict, List, Optional | ||
|
||
import torch | ||
from flwr.common.logger import log | ||
from flwr.common.typing import Metrics | ||
from sklearn.metrics import confusion_matrix | ||
|
||
from examples.fedopt_example.client_data import LabelEncoder | ||
from fl4health.utils.metrics import MetricMeter | ||
from fl4health.utils.metrics import Metric | ||
|
||
|
||
class Outcome: | ||
|
@@ -75,20 +75,26 @@ def compute_metrics(self) -> Metrics: | |
return metrics | ||
|
||
|
||
class CustomMetricMeter(MetricMeter): | ||
def __init__(self, label_encoder: LabelEncoder) -> None: | ||
class CompoundMetric(Metric): | ||
def __init__(self, name: str) -> None: | ||
""" | ||
This class is used to compute metrics associated with the AG's News task. There are a number of classes and | ||
we want to accumulate a bunch of statistics all at once to facilitate the computation of a number of different | ||
metrics for this problem. As such, we define our own MetricMeter and bypass the standard metric meter | ||
implementations, which calculate separate metrics individually. | ||
metrics for this problem. As such, we define our own Metric class and bypass the standard SimpleMetric class, | ||
which calculate separate metrics individually. | ||
|
||
Args: | ||
label_encoder (LabelEncoder): This class is used to determine the mapping of integers to label names for | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe transfer this comment about |
||
the AG's news task. | ||
name (str): The name of the compound metric. | ||
""" | ||
super().__init__(name) | ||
self.true_preds = 0 | ||
self.total_preds = 0 | ||
self.classes: List[str] | ||
self.label_to_class: Dict[int, str] | ||
self.n_classes: int | ||
self.outcome_dict: Dict[str, Outcome] | ||
|
||
def _setup(self, label_encoder: LabelEncoder) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typically the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, good call! I initially was aiming for it to be internal but realized we have to call it externally so I forgot to change it back |
||
self.classes = label_encoder.classes | ||
self.outcome_dict = self._initialize_outcomes(self.classes) | ||
self.label_to_class = label_encoder.label_to_class | ||
|
@@ -117,7 +123,7 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None: | |
self.outcome_dict[true_class].false_negative += count | ||
self.outcome_dict[pred_class].false_positive += count | ||
|
||
def compute(self) -> Metrics: | ||
def compute(self, name: Optional[str]) -> Metrics: | ||
sum_f1 = 0.0 | ||
results: Metrics = {"total_preds": self.total_preds, "true_preds": self.true_preds} | ||
log_string = "" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we give this metric a name?