Metric refactor #69
Conversation
…now baked into Metric class. User only needs to define the __call__ method to calculate a metric given inputs and targets. Also changed the return type of predict so that we can pass predictions and features in two separate dictionaries. Small changes to examples to integrate the aforementioned changes
…ctionairies for predictions and features, bringing it in line with other methods
…tions and features. Also added some comments throughout
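The commit messages above describe the refactored design: state and bookkeeping live in the Metric base class, and a user only defines __call__. A minimal hypothetical sketch of that shape (class and method names assumed for illustration, not taken from the actual diff):

```python
import torch


class Metric:
    # Hypothetical sketch: the base class owns the name and any accumulation
    # state, so subclasses only implement the computation itself.
    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, preds: torch.Tensor, targets: torch.Tensor) -> float:
        raise NotImplementedError


class Accuracy(Metric):
    # The user only defines __call__ to compute the metric from inputs and targets.
    def __call__(self, preds: torch.Tensor, targets: torch.Tensor) -> float:
        return (preds.argmax(dim=1) == targets).float().mean().item()
```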
examples/fedopt_example/metrics.py
Outdated
self.n_classes: int
self.outcome_dict: Dict[str, Outcome]

def _setup(self, label_encoder: LabelEncoder) -> None:
Typically the _ prefix is reserved for protected or private methods. That is, methods that are exclusively called within the class itself, rather than externally. All that is to say, I would recommend dropping the _ based on the way this is being used 🙂
Yeah, good call! I was initially aiming for it to be internal but realized we have to call it externally, and I forgot to change it back

Args:
    label_encoder (LabelEncoder): This class is used to determine the mapping of integers to label names for
Maybe transfer this comment about label_encoder to the setup method below?
examples/fedopt_example/client.py
Outdated
@@ -99,5 +86,5 @@ def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
- client = NewsClassifierClient(data_path, DEVICE)
+ client = NewsClassifierClient(data_path, [CompoundMetric("")], DEVICE)
Should we give this metric a name?
fl4health/clients/basic_client.py
Outdated
"prediction": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter")
}
self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map)
self.train_metric_meter_mngr = MetricManager(metrics=self.metrics, metric_mngr_name="train")
Since these are now "MetricManager" objects, should we change the names from train_metric_meter_mngr to train_metric_manager and val_metric_meter_mngr to val_metric_manager throughout?
fl4health/utils/metrics.py
Outdated
Class to manage one or metric meters.
"""
Args:
    preds (Dict[str, torch.Tensor]): A dictionairy of
This comment is incomplete. Also, I think there are a few places throughout the code where "dictionary" is spelled incorrectly in this way 😂
fl4health/clients/basic_client.py
Outdated
Returns:
    Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
    contains predictions indexed by name and the second element contains intermediate activations
    index by name. BY passing features, we can compute losses such as the model contrasting loss in MOON.
I think you mean to capitalize BY, but just want to check

Returns:
    Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
    contains predictions indexed by name and the second element contains intermediate activations
It's probably worth mentioning that anything stored in predictions will be used to compute metrics, that way people don't just store a bunch of stuff in there accidentally?
I see you added a comment in the compute loss function about anything stored in preds being used to compute metrics, but I think we should also put that in the comment for the predict function, since that's where people will define what's in the predictions dictionary. So they "know" this before potentially stuffing extra things into that dictionary.
Good call!
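The suggestion above is that predict's docstring should warn that everything placed in the predictions dictionary is used to compute metrics. A hypothetical sketch of what such an override might look like (a standalone function here for illustration; the actual method lives on the client class):

```python
from typing import Dict, Tuple

import torch
import torch.nn as nn


def predict(model: nn.Module, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    # Illustrative sketch (names assumed): return predictions and intermediate
    # features as two separate dictionaries.
    # Note: every tensor placed in the predictions dictionary will be used to
    # compute metrics, so only store genuine predictions here. Auxiliary
    # activations (e.g. for losses such as MOON's contrastive loss) belong in
    # the features dictionary instead.
    output = model(input)
    preds = {"prediction": output}
    features: Dict[str, torch.Tensor] = {}
    return preds, features
```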
fl4health/clients/basic_client.py
Outdated
In the default case, the dict has a single item with key prediction.
In more complicated approaches such as APFL, the dict has as many items as prediction types
User can override for more complex logic.
Computes the prediction(s) (and potentially features) of the model(s) given the input.
Super minor, but I think you can drop the parentheses here and just write
Computes the prediction(s), and optionally features, of the model(s) given the input.
fl4health/clients/evaluate_client.py
Outdated
@@ -42,9 +41,9 @@ def __init__(
self.data_loader: DataLoader
self.criterion: _Loss
self.global_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
- self.global_metric_meter = MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "global_eval_meter")
+ self.global_metric_meter = MetricManager(self.metrics, "global_eval_meter")
Maybe change the names of these properties to be something like global_metric_manager and local_metric_manager?
fl4health/clients/fed_prox_client.py
Outdated
self, preds: Dict[str, torch.Tensor], features: Dict[str, torch.Tensor], target: torch.Tensor
) -> Losses:
"""
Computes loss given predictions of the model and ground truth data.
Maybe add a description that we're also adding in the proximal loss, computed from the l2 norm between the initial and final weights of local training
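For context, a minimal sketch of the proximal term being described, using the common FedProx formulation (mu / 2 times the squared l2 distance between the weights at the start of local training and the current weights); the function name and signature are assumptions for illustration:

```python
from typing import List

import torch


def proximal_loss(
    initial_weights: List[torch.Tensor], current_weights: List[torch.Tensor], mu: float
) -> torch.Tensor:
    # Hypothetical sketch: (mu / 2) * ||w - w0||^2, summed over all parameter
    # tensors, penalizing drift from the weights at the start of local training.
    distance_sq = sum(torch.sum((w0 - w) ** 2) for w0, w in zip(initial_weights, current_weights))
    return 0.5 * mu * distance_sq
```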
fl4health/clients/fenda_client.py
Outdated

Returns:
    Losses: Object containing checkpoint loss, backward loss and additional losses indexed by name.
    Additional losses includes proximal loss.
No proximal loss in these calculations 🙂
fl4health/clients/moon_client.py
Outdated
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
index by name. Specificaly the features of the model, features of the global model and features of
the old model are passed.
Super minor, but I would say "are returned" rather than "passed"
fl4health/model_bases/fenda_base.py
Outdated
"local_features": local_output.reshape(len(local_output), -1),
"global_features": global_output.reshape(len(global_output), -1),
}
# Return preds and features as seperate dictionairy as in moon base
seperate -> separate 🙂
fl4health/utils/metrics.py
Outdated
def clear(self) -> None:
    self.metric_values_history = [[] for _ in range(len(self.metrics))]
    self.counts = []
self.og_metrics = metrics
I'm not sure what the og here stands for, unless you meant "original gangster" metrics, which would be funny, but probably not necessary 😂
hahah not quite, it's just a short form I have used for original, but I can change it to avoid that interpretation
fl4health/utils/metrics.py
Outdated
Args:
    preds (Dict[str, torch.Tensor]): A dictionairy of
"""
if len(self.metrics_per_prediction_type) == 0:
Super minor, but I think you can just do if self.metrics_per_prediction_type:
I don't think this is the case for dictionaries, just verified in python interpreter
huh...I thought I verified it in an interpreter too 😂
>>> dict_empty = {}
>>> dict_filled = {"a": "b"}
>>> print("HI") if dict_empty else print("BYE")
BYE
>>> print("HI") if dict_filled else print("BYE")
HI
I'm fine with leaving it though. It's very minor
Sorry, it's because if self.metrics_per_prediction_type: should be if not self.metrics_per_prediction_type:. My bad hahaha. Updated in my most recent commit.
Right...also my bad for not writing the correct condition 🤦
fl4health/utils/metrics.py
Outdated

def __init__(self, key_to_meter_map: Dict[str, MetricMeter]):
    self.key_to_meter_map = key_to_meter_map
for pred, mtrcs in zip(preds.values(), self.metrics_per_prediction_type.values()):
I'm a bit wary of doing the zip here, as it assumes that the keys of the dictionaries are ordered in the same way and are the same length, and it won't fail if that is not the case. Maybe we do something like
assert len(preds) == len(self.metrics_per_prediction_type)
for prediction_key, pred in preds.items():
    metrics_for_prediction_type = self.metrics_per_prediction_type[prediction_key]
    for metric_for_prediction_type in metrics_for_prediction_type:
        metric_for_prediction_type.update(pred, target)
Yeah, I agree this is a better way to go about it. Just one small issue: the length of metrics_per_prediction_type is going to be the number of prediction types, not the number of metrics. Thus I instead assert that the list of metrics at a given key is as long as preds.
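The key-based lookup agreed on above might look something like the following sketch (function and parameter names assumed; in the actual code this logic lives on the metric manager class). Indexing metrics by prediction key avoids the ordering assumption that zip makes:

```python
from typing import Callable, Dict, Sequence

import torch


def update_metrics(
    preds: Dict[str, torch.Tensor],
    target: torch.Tensor,
    metrics_per_prediction_type: Dict[str, Sequence],
) -> None:
    # Guard against mismatched dictionaries instead of silently zipping
    # possibly differently-ordered or differently-sized key sets.
    assert len(preds) == len(metrics_per_prediction_type)
    for prediction_key, pred in preds.items():
        # Look up the metrics for this prediction type by key, then update each.
        for metric in metrics_per_prediction_type[prediction_key]:
            metric.update(pred, target)
```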
fl4health/utils/metrics.py
Outdated
for meter in self.key_to_meter_map.values():
    result = meter.compute()
    all_results.update(result)
for metrics_key, mtrcs in self.metrics_per_prediction_type.items():
Super minor, but let's avoid the abbreviation here, since metrics is pretty short anyway.
fl4health/utils/metrics.py
Outdated
self.metric_values_history = [[] for _ in range(len(self.metrics))]
self.counts = []
self.og_metrics = metrics
self.metric_mngr_name = metric_mngr_name
any objection to expanding these to metric_manager_name and metric_manager_name, respectively, since we're only saving three letters in the abbreviation anyway?
There are a few other places in the code where we use this abbreviation that I'd suggest we expand as well, unless you really don't like it 🙂
Overall, I think this is an awesome refactor. All the comments I left are quite minor and I didn't see anything major that was missed.
fl4health/clients/fed_prox_client.py
Outdated
@@ -98,7 +98,8 @@ def compute_loss(
self, preds: Dict[str, torch.Tensor], features: Dict[str, torch.Tensor], target: torch.Tensor
) -> Losses:
"""
- Computes loss given predictions of the model and ground truth data.
+ Computes loss given predictions of the model and ground truth data. Adds to objective by including
+ proximal loss which is the L2 norm between the initial and final weights of local training.
I'm going to be a pedantic mathematician and have you lower-case l2 here. L^2/L_2 norms operate on functions not vectors lol
hahahah don't ever not be a pedantic mathematician, it's good to know these notation conventions
Changes look good to me. Left two additional small comments. Feel free to take them or leave them
PR Type
[Feature | Fix | Documentation | Other() ]
Short Description
Clickup Story: Refactor Client Metrics
Refactor client metrics so that they maintain state, eliminating the need for MetricMeters, which tend to make things more complicated. Also add the option to return features from models through the predict method, to be used in loss computation. Then apply this to the MOON client to compute the contrastive loss. Added some documentation with proper formatting to Metric-related code and a few client methods that I altered slightly in this PR. In a follow-up PR, I will apply the proper formatting more extensively to client code and more broadly.
Note: This should only be reviewed once the Create Fixed Requirements File for FLamby, Update Dynamic Weight Exchanger and FedOpt Example PR is merged. I just wanted to base my PR on that branch to adapt David's CustomMetricMeter (and a few other relevant parts) to the simplified Metric tracking.
Tests Added