
Metric refactor #69

Merged
merged 12 commits into main from metric-refactor
Nov 21, 2023

Conversation


@jewelltaylor jewelltaylor commented Nov 4, 2023

PR Type

[Feature | Fix | Documentation | Other() ]

Short Description

Clickup Story: Refactor Client Metrics

Refactor client metrics so that they maintain state, eliminating the need for MetricMeters, which tend to just make things more complicated. Also add the option to return features from models through the predict method, to be used in loss computation. Then apply this to the MOON client to compute the contrastive loss. Added some documentation with proper formatting to Metric-related code and to a few client methods that I altered slightly in this PR. In a follow-up PR, I will apply the proper formatting more extensively to client code and more broadly.
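
For illustration, a minimal sketch of what a stateful metric along these lines might look like; the class name and method set here are hypothetical, not the code in this PR:

import torch

class Accuracy:
    # Hypothetical stateful metric: the metric itself accumulates batch results,
    # so no separate MetricMeter is required.

    def __init__(self, name: str = "accuracy") -> None:
        self.name = name
        self.correct = 0
        self.total = 0

    def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> float:
        # The user only defines how to score a single batch of inputs and targets.
        return (preds.argmax(dim=1) == target).float().mean().item()

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # State lives on the metric rather than in a meter object.
        self.correct += int((preds.argmax(dim=1) == target).sum().item())
        self.total += int(target.numel())

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0

    def clear(self) -> None:
        self.correct = 0
        self.total = 0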

Note: This should only be reviewed once the "Create Fixed Requirements File for FLamby, Update Dynamic Weight Exchanger and FedOpt Example" PR is merged. I just wanted to base my PR on that branch to adapt David's CustomMetricMeter (and a few other relevant parts) to the simplified Metric tracking.

Tests Added

  • Adapted Metric-related tests that were affected by the refactor
  • Added a test for the updated MetricMeter

…now baked into the Metric class. The user only needs to define the __call__ method to calculate the metric given inputs and targets. Also changed the return type of predict so that we can pass predictions and features in two separate dictionaries. Small changes to examples to integrate the aforementioned changes.
@jewelltaylor
Collaborator Author

Now that PR #68 is approved and will presumably be merged into main with minimal changes, I think this PR is ready for review. I had to make some additional changes to update the parts of the code affected by Sana's recent PR #66. I would appreciate any feedback!

self.n_classes: int
self.outcome_dict: Dict[str, Outcome]

def _setup(self, label_encoder: LabelEncoder) -> None:
Collaborator

Typically the _ prefix is reserved for protected or private methods. That is, methods that are exclusively called within the class itself, rather than externally. All that is to say, I would recommend dropping the _ based on the way this is being used 🙂
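
For anyone unfamiliar with the convention being referenced, a tiny hypothetical example:

class LabelProcessor:
    def setup(self) -> None:
        # Public method: intended to be called from outside the class.
        self._build_index()

    def _build_index(self) -> None:
        # Leading underscore: internal helper, not part of the public interface.
        pass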

Collaborator Author

Yeah, good call! I was initially aiming for it to be internal, but then realized we have to call it externally and forgot to change it back.


Args:
label_encoder (LabelEncoder): This class is used to determine the mapping of integers to label names for
Collaborator

Maybe transfer this comment about label_encoder to the setup method below?

@@ -99,5 +86,5 @@ def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = NewsClassifierClient(data_path, DEVICE)
client = NewsClassifierClient(data_path, [CompoundMetric("")], DEVICE)
Collaborator

Should we give this metric a name?

"prediction": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter")
}
self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map)
self.train_metric_meter_mngr = MetricManager(metrics=self.metrics, metric_mngr_name="train")
Collaborator

Since these are now "MetricManager" objects, should we change the names from train_metric_meter_mngr to train_metric_manager and val_metric_meter_mngr to val_metric_manager throughout?

Class to manage one or metric meters.
"""
Args:
preds (Dict[str, torch.Tensor]): A dictionairy of
Collaborator

This comment is incomplete. Also, I think there are a few places throughout the code where "dictionary" is spelled incorrectly in this way 😂

Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
index by name. BY passing features, we can compute losses such as the model contrasting loss in MOON.
Collaborator

I think you mean to capitalize BY, but just want to check


Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
Collaborator

@emersodb emersodb Nov 20, 2023

It's probably worth mentioning that anything stored in predictions will be used to compute metrics, that way people don't just store a bunch of stuff in there accidentally?

Collaborator

I see you added a comment in the compute loss function about anything stored in preds being used to compute metrics, but I think we should also put that in the comment for the predict function since that's where people will define what's in the predictions dictionary. So they "know" this before potentially stuffing extra things into that dictionary.

Collaborator

Good call!

In the default case, the dict has a single item with key prediction.
In more complicated approaches such as APFL, the dict has as many items as prediction types
User can override for more complex logic.
Computes the prediction(s) (and potentially features) of the model(s) given the input.
Collaborator

Super minor, but I think you can drop the parentheses here and just write

Computes the prediction(s), and optionally features, of the model(s) given the input.
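
To make the two-dictionary return shape being discussed concrete, here is a rough, hypothetical sketch (the toy model and names are assumptions, not the PR's code):

from typing import Dict, Tuple
import torch
import torch.nn as nn

class SketchClient:
    def __init__(self, model: nn.Sequential) -> None:
        self.model = model

    def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        features = self.model[0](input)   # intermediate activations
        logits = self.model[1](features)  # final predictions
        # Everything in the first dictionary is used for metric computation; the second
        # dictionary carries features for loss terms such as MOON's contrastive loss.
        return {"prediction": logits}, {"features": features.reshape(len(features), -1)}

# Toy usage with a two-stage model:
preds, features = SketchClient(nn.Sequential(nn.Linear(10, 8), nn.Linear(8, 3))).predict(torch.randn(4, 10))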

@@ -42,9 +41,9 @@ def __init__(
self.data_loader: DataLoader
self.criterion: _Loss
self.global_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
self.global_metric_meter = MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "global_eval_meter")
self.global_metric_meter = MetricManager(self.metrics, "global_eval_meter")
Collaborator

Maybe change the names of these properties to be something like global_metric_manager and local_metric_manager?

self, preds: Dict[str, torch.Tensor], features: Dict[str, torch.Tensor], target: torch.Tensor
) -> Losses:
"""
Computes loss given predictions of the model and ground truth data.
Collaborator

Maybe add a description that we're also adding in the proximal loss, comparing the l2 norm between the initial and final weights of local training
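
For reference, the term being described is roughly the FedProx-style penalty sketched below; this is a hedged sketch with hypothetical names, not FL4Health's actual implementation:

from typing import Sequence
import torch

def proximal_loss(
    current_params: Sequence[torch.Tensor],
    initial_params: Sequence[torch.Tensor],
    mu: float = 0.1,
) -> torch.Tensor:
    # (mu / 2) * squared l2 distance between the current local weights and the
    # weights at the start of the local training round.
    total = torch.tensor(0.0)
    for w, w0 in zip(current_params, initial_params):
        total = total + torch.sum((w - w0.detach()) ** 2)
    return (mu / 2.0) * total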


Returns:
Losses: Object containing checkpoint loss, backward loss and additional losses indexed by name.
Additional losses includes proximal loss.
Collaborator

No proximal loss in these calculations 🙂

Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
index by name. Specificaly the features of the model, features of the global model and features of
the old model are passed.
Collaborator

Super minor: I would say "are returned" rather than "passed".
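
To show why these feature sets are returned, a hedged sketch of a MOON-style contrastive term computed from them (illustrative only, with an assumed temperature parameter):

import torch
import torch.nn.functional as F

def moon_contrastive_loss(
    features: torch.Tensor,         # features of the current local model
    global_features: torch.Tensor,  # features of the global model (positive pair)
    old_features: torch.Tensor,     # features of the previous local model (negative pair)
    temperature: float = 0.5,
) -> torch.Tensor:
    pos = F.cosine_similarity(features, global_features, dim=1) / temperature
    neg = F.cosine_similarity(features, old_features, dim=1) / temperature
    logits = torch.stack([pos, neg], dim=1)
    # The "correct" class is the similarity to the global model's features.
    labels = torch.zeros(features.shape[0], dtype=torch.long)
    return F.cross_entropy(logits, labels)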

"local_features": local_output.reshape(len(local_output), -1),
"global_features": global_output.reshape(len(global_output), -1),
}
# Return preds and features as seperate dictionairy as in moon base
Collaborator

seperate -> separate 🙂

def clear(self) -> None:
self.metric_values_history = [[] for _ in range(len(self.metrics))]
self.counts = []
self.og_metrics = metrics
Collaborator

I'm not sure what the og here stands for unless you meant "original gangster" metrics, which would be funny, but probably not necessary 😂

Collaborator Author

Hahah, not quite, it's just a short form I have used for "original", but I can change it to avoid that interpretation.

Args:
preds (Dict[str, torch.Tensor]): A dictionairy of
"""
if len(self.metrics_per_prediction_type) == 0:
Collaborator

@emersodb emersodb Nov 20, 2023

super minor, but I think you can just do if self.metrics_per_prediction_type:

Collaborator Author

I don't think this is the case for dictionaries; just verified it in the Python interpreter.

Collaborator

huh...I thought I verified it in an interpreter too 😂

>>> dict_empty = {}
>>> dict_filled = {"a": "b"}
>>> print("HI") if dict_empty else print("BYE")
BYE
>>> print("HI") if dict_filled else print("BYE")
HI

Collaborator

I'm fine with leaving it though. It's very minor

Collaborator Author

@jewelltaylor jewelltaylor Nov 20, 2023

Sorry, it's because if self.metrics_per_prediction_type: should be if not self.metrics_per_prediction_type:. My bad hahaha. Updated in my most recent commit.

Collaborator

Right...also my bad for not writing the correct condition 🤦


def __init__(self, key_to_meter_map: Dict[str, MetricMeter]):
self.key_to_meter_map = key_to_meter_map
for pred, mtrcs in zip(preds.values(), self.metrics_per_prediction_type.values()):
Collaborator

I'm a bit wary of doing the zip here, as it assumes that the keys of the dictionaries are ordered in the same way and are the same length and won't fail if that is not the case. Maybe we do something like

assert len(preds) == len(self.metrics_per_prediction_type)
for prediction_key, pred in preds.items():
    metrics_for_prediction_type = self.metrics_per_prediction_type[prediction_key]
    for metric_for_prediction_type in metrics_for_prediction_type:
        metric_for_prediction_type.update(pred, target)

Collaborator Author

Yeah, I agree this is a better way to go about it. Just one small issue: the length of metrics_per_prediction_type is going to be the number of prediction types, not the number of metrics. Thus I instead assert that the list of metrics at a given key is as long as preds.
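
A minimal sketch of what the adjusted update loop could look like, assuming metrics_per_prediction_type maps each prediction key to its list of metrics (hypothetical, not the exact committed code):

from typing import Dict, List
import torch

class MetricManagerSketch:
    def __init__(self, metrics_per_prediction_type: Dict[str, List]) -> None:
        self.metrics_per_prediction_type = metrics_per_prediction_type

    def update(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
        # Look metrics up by prediction key instead of zipping two dicts, so a missing
        # or reordered key fails loudly rather than silently mismatching.
        assert len(preds) == len(self.metrics_per_prediction_type)
        for prediction_key, pred in preds.items():
            for metric in self.metrics_per_prediction_type[prediction_key]:
                metric.update(pred, target)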

for meter in self.key_to_meter_map.values():
result = meter.compute()
all_results.update(result)
for metrics_key, mtrcs in self.metrics_per_prediction_type.items():
Collaborator

Super minor, but let's avoid the abbreviation here, since metrics is pretty short anyway.

self.metric_values_history = [[] for _ in range(len(self.metrics))]
self.counts = []
self.og_metrics = metrics
self.metric_mngr_name = metric_mngr_name
Collaborator

any objection to expanding these to metric_manager_name and metric_manager_name, respectively, since we're only saving three letters in the abbreviation anyway?

Collaborator

There are a few other places in the code where we use this abbreviation that I'd suggest we expand as well, unless you really don't like it 🙂

Collaborator

@emersodb emersodb left a comment

Overall, I think this is an awesome refactor. All the comments I left are quite minor and I didn't see anything major that was missed.

@@ -98,7 +98,8 @@ def compute_loss(
self, preds: Dict[str, torch.Tensor], features: Dict[str, torch.Tensor], target: torch.Tensor
) -> Losses:
"""
Computes loss given predictions of the model and ground truth data.
Computes loss given predictions of the model and ground truth data. Adds to objective by including
proximal loss which is the L2 norm between the initial and final weights of local training.
Collaborator

I'm going to be a pedantic mathematician and have you lower-case l2 here. L^2/L_2 norms operate on functions not vectors lol

Collaborator Author

Hahahah, don't ever not be a pedantic mathematician; it's good to know these notation conventions.

@emersodb emersodb self-requested a review November 21, 2023 14:03
Collaborator

@emersodb emersodb left a comment

Changes look good to me. Left two additional small comments. Feel free to take them or leave them

@jewelltaylor jewelltaylor merged commit bc143bc into main Nov 21, 2023
2 checks passed
@jewelltaylor jewelltaylor deleted the metric-refactor branch November 21, 2023 14:41