From 8de4271cec4632d094a9f14afd1db6b034067e73 Mon Sep 17 00:00:00 2001
From: Zain Huda
Date: Thu, 19 Oct 2023 14:52:30 -0700
Subject: [PATCH] Segmented NE Metric (#1446)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1446

Introduces the Segmented NE metric, with an interface very similar to grouped AUC: pass in an additional tensor of `grouping_keys` and NE is computed for each segment. Up to N segments are supported, determined by the `num_groups` argument passed to `SegmentedNEMetric` when it is instantiated, for example:

```
ne = SegmentedNEMetric(
    world_size=1,
    my_rank=0,
    batch_size=batch_size,
    tasks=task_list,
    num_groups=2
)
```

Returns NE per segment, suffixed with the segment index:

```
{'segmented_ne-Task:0|lifetime_segmented_ne_0': tensor(3.1615),
 'segmented_ne-Task:0|lifetime_segmented_ne_1': tensor(1.6004)}
```
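For reference, a minimal sketch of how the metric is fed at update time, adapted from the unit test added in this diff (the task name `Task:0` and the tensor values are illustrative only): `grouping_keys` is passed through `required_inputs` and must be a `torch.int64` tensor with one segment id per example.

```
labels = {"Task:0": torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])}
predictions = {"Task:0": torch.tensor([0.2, 0.6, 0.8, 0.4, 0.9])}
weights = {"Task:0": torch.tensor([0.13, 0.2, 0.5, 0.8, 0.75])}
# one segment id per example; dtype must be torch.int64
grouping_keys = torch.tensor([0, 1, 0, 1, 1], dtype=torch.int64)

ne.update(
    predictions=predictions,
    labels=labels,
    weights=weights,
    required_inputs={"grouping_keys": grouping_keys},
)
print(ne.compute())
```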
Reviewed By: paw-lu

Differential Revision: D50107324

fbshipit-source-id: aafed21846b24d6dbfd916e9ca4925851429c115
---
 torchrec/metrics/metric_module.py           |   2 +
 torchrec/metrics/metrics_config.py          |   1 +
 torchrec/metrics/metrics_namespace.py       |   2 +
 torchrec/metrics/segmented_ne.py            | 289 ++++++++++++++++++++
 torchrec/metrics/tests/test_segmented_ne.py | 150 ++++++++++
 5 files changed, 444 insertions(+)
 create mode 100644 torchrec/metrics/segmented_ne.py
 create mode 100644 torchrec/metrics/tests/test_segmented_ne.py

diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py
index 44cb368d0..1a3e905c8 100644
--- a/torchrec/metrics/metric_module.py
+++ b/torchrec/metrics/metric_module.py
@@ -40,6 +40,7 @@ from torchrec.metrics.ne import NEMetric
 from torchrec.metrics.rec_metric import RecMetric, RecMetricList
 from torchrec.metrics.recall_session import RecallSessionMetric
+from torchrec.metrics.segmented_ne import SegmentedNEMetric
 from torchrec.metrics.throughput import ThroughputMetric
 from torchrec.metrics.tower_qps import TowerQPSMetric
 from torchrec.metrics.weighted_avg import WeightedAvgMetric
@@ -49,6 +50,7 @@ REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
     RecMetricEnum.NE: NEMetric,
+    RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric,
     RecMetricEnum.CTR: CTRMetric,
     RecMetricEnum.CALIBRATION: CalibrationMetric,
     RecMetricEnum.AUC: AUCMetric,
diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py
index a855f18a9..bd2f52a2d 100644
--- a/torchrec/metrics/metrics_config.py
+++ b/torchrec/metrics/metrics_config.py
@@ -18,6 +18,7 @@ class RecMetricEnumBase(StrValueMixin, Enum):
 class RecMetricEnum(RecMetricEnumBase):
     NE = "ne"
+    SEGMENTED_NE = "segmented_ne"
     LOG_LOSS = "log_loss"
     CTR = "ctr"
     AUC = "auc"
diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py
index d81ec0ed6..3bac90a89 100644
--- a/torchrec/metrics/metrics_namespace.py
+++ b/torchrec/metrics/metrics_namespace.py
@@ -38,6 +38,7 @@ class MetricName(MetricNameBase):
     DEFAULT = ""
     NE = "ne"
+    SEGMENTED_NE = "segmented_ne"
     LOG_LOSS = "logloss"
     THROUGHPUT = "throughput"
     TOTAL_EXAMPLES = "total_examples"
@@ -64,6 +65,7 @@ class MetricNamespace(MetricNamespaceBase):
     DEFAULT = ""
     NE = "ne"
+    SEGMENTED_NE = "segmented_ne"
     THROUGHPUT = "throughput"
     CTR = "ctr"
     CALIBRATION = "calibration"
diff --git a/torchrec/metrics/segmented_ne.py b/torchrec/metrics/segmented_ne.py
new file mode 100644
index 000000000..b2e21c199
--- /dev/null
+++ b/torchrec/metrics/segmented_ne.py
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+from typing import Any, cast, Dict, List, Optional, Type
+
+import torch
+from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
+from torchrec.metrics.rec_metric import (
+    MetricComputationReport,
+    RecMetric,
+    RecMetricComputation,
+    RecMetricException,
+)
+
+PREDICTIONS = "predictions"
+LABELS = "labels"
+WEIGHTS = "weights"
+SEGMENTS = "segments"
+
+
+def compute_cross_entropy(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+    eta: float,
+) -> torch.Tensor:
+    predictions = predictions.double()
+    predictions.clamp_(min=eta, max=1 - eta)
+    cross_entropy = -weights * labels * torch.log2(predictions) - weights * (
+        1.0 - labels
+    ) * torch.log2(1.0 - predictions)
+    return cross_entropy
+
+
+def _compute_cross_entropy_norm(
+    mean_label: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    eta: float,
+) -> torch.Tensor:
+    mean_label = mean_label.double()
+    mean_label.clamp_(min=eta, max=1 - eta)
+    return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2(
+        1.0 - mean_label
+    )
+
+
+def compute_ne_helper(
+    ce_sum: torch.Tensor,
+    weighted_num_samples: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    eta: float,
+) -> torch.Tensor:
+    mean_label = pos_labels / weighted_num_samples
+    ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta)
+    return ce_sum / ce_norm
+
+
+def compute_logloss(
+    ce_sum: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    eta: float,
+) -> torch.Tensor:
+    # we utilize tensor broadcasting for operations
+    labels_sum = pos_labels + neg_labels
+    labels_sum.clamp_(min=eta)
+    return ce_sum / labels_sum
+
+
+def compute_ne(
+    ce_sum: torch.Tensor,
+    weighted_num_samples: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    num_groups: int,
+    eta: float,
+) -> torch.Tensor:
+    # size should be (num_groups)
+    result_ne = torch.zeros(num_groups)
+    for group in range(num_groups):
+        mean_label = pos_labels[group] / weighted_num_samples[group]
+        ce_norm = _compute_cross_entropy_norm(
+            mean_label, pos_labels[group], neg_labels[group], eta
+        )
+        ne = ce_sum[group] / ce_norm
+        result_ne[group] = ne
+
+    # ne indexed by group - tensor size (num_groups)
+    return result_ne
+
+
+def get_segemented_ne_states(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+    grouping_keys: torch.Tensor,
+    eta: float,
+    num_groups: int,
+) -> Dict[str, torch.Tensor]:
+    groups = torch.unique(grouping_keys)
+    cross_entropy, weighted_num_samples, pos_labels, neg_labels = (
+        torch.zeros(num_groups),
+        torch.zeros(num_groups),
+        torch.zeros(num_groups),
+        torch.zeros(num_groups),
+    )
+    for group in groups:
+        group_mask = grouping_keys == group
+
+        group_labels = labels[group_mask]
+        group_predictions = predictions[group_mask]
+        group_weights = weights[group_mask]
+
+        ce_sum_group = torch.sum(
+            compute_cross_entropy(
+                labels=group_labels,
+                predictions=group_predictions,
+                weights=group_weights,
+                eta=eta,
+            ),
+            dim=-1,
+        )
+
+        weighted_num_samples_group = torch.sum(group_weights, dim=-1)
+        pos_labels_group = torch.sum(group_weights * group_labels, dim=-1)
+        neg_labels_group = torch.sum(group_weights * (1.0 - group_labels), dim=-1)
+
+        cross_entropy[group] = ce_sum_group.item()
+        weighted_num_samples[group] = weighted_num_samples_group.item()
+        pos_labels[group] = pos_labels_group.item()
+        neg_labels[group] = neg_labels_group.item()
+
+    # tensor size for each value is (num_groups)
+    return {
+        "cross_entropy_sum": cross_entropy,
+        "weighted_num_samples": weighted_num_samples,
+        "pos_labels": pos_labels,
+        "neg_labels": neg_labels,
+    }
+
+
+def _state_reduction_sum(state: torch.Tensor) -> torch.Tensor:
+    return state.sum(dim=0)
+
+
+class SegmentedNEMetricComputation(RecMetricComputation):
+    r"""
+    This class implements the RecMetricComputation for Segmented NE, i.e. Normalized Entropy, for boolean labels.
+
+    Only binary labels (0s and 1s) are currently supported; NE is computed for each segment. NE across the whole
+    model output can still be obtained through the regular NE metric.
+
+    The constructor arguments are defined in RecMetricComputation.
+    See the docstring of RecMetricComputation for more detail.
+
+    Args:
+        include_logloss (bool): return vanilla logloss as an additional metric result, on top of segmented NE.
+        num_groups (int): number of segments (distinct values in `grouping_keys`) to compute NE for.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        include_logloss: bool = False,  # TODO - include
+        num_groups: int = 1,
+        **kwargs: Any,
+    ) -> None:
+        self._include_logloss: bool = include_logloss
+        super().__init__(*args, **kwargs)
+        self._num_groups = num_groups  # would there be checkpointing issues with this? maybe make this state
+        self._add_state(
+            "cross_entropy_sum",
+            torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
+            add_window_state=False,
+            dist_reduce_fx=_state_reduction_sum,
+            persistent=True,
+        )
+        self._add_state(
+            "weighted_num_samples",
+            torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
+            add_window_state=False,
+            dist_reduce_fx=_state_reduction_sum,
+            persistent=True,
+        )
+        self._add_state(
+            "pos_labels",
+            torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
+            add_window_state=False,
+            dist_reduce_fx=_state_reduction_sum,
+            persistent=True,
+        )
+        self._add_state(
+            "neg_labels",
+            torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
+            add_window_state=False,
+            dist_reduce_fx=_state_reduction_sum,
+            persistent=True,
+        )
+        self.eta = 1e-12
+
+    def update(
+        self,
+        *,
+        predictions: Optional[torch.Tensor],
+        labels: torch.Tensor,
+        weights: Optional[torch.Tensor],
+        **kwargs: Dict[str, Any],
+    ) -> None:
+        if predictions is None or weights is None:
+            raise RecMetricException(
+                "Inputs 'predictions' and 'weights' should not be None for SegmentedNEMetricComputation update"
+            )
+        elif (
+            "required_inputs" not in kwargs
+            or kwargs["required_inputs"].get("grouping_keys") is None
+        ):
+            raise RecMetricException(
+                f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}"
+            )
+        elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64:
+            raise RecMetricException(
+                f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}."
+ ) + + grouping_keys = kwargs["required_inputs"]["grouping_keys"] + states = get_segemented_ne_states( + labels, + predictions, + weights, + grouping_keys, + eta=self.eta, + num_groups=self._num_groups, + ) + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + + def _compute(self) -> List[MetricComputationReport]: + reports = [] + computed_ne = compute_ne( + self.cross_entropy_sum[0], + self.weighted_num_samples[0], + self.pos_labels[0], + self.neg_labels[0], + num_groups=self._num_groups, + eta=self.eta, + ) + + for group in range(self._num_groups): + reports.append( + MetricComputationReport( + name=MetricName.SEGMENTED_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=computed_ne[group], + description="_" + str(group), + ), + ) + + if self._include_logloss: + log_loss_groups = compute_logloss( + self.cross_entropy_sum[0], + self.pos_labels[0], + self.neg_labels[0], + eta=self.eta, + ) + + for group in range(self._num_groups): + reports.append( + MetricComputationReport( + name=MetricName.LOG_LOSS, + metric_prefix=MetricPrefix.LIFETIME, + value=log_loss_groups[group], + description="_" + str(group), + ) + ) + + return reports + + +class SegmentedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SEGMENTED_NE + _computation_class: Type[RecMetricComputation] = SegmentedNEMetricComputation diff --git a/torchrec/metrics/tests/test_segmented_ne.py b/torchrec/metrics/tests/test_segmented_ne.py new file mode 100644 index 000000000..939787e0d --- /dev/null +++ b/torchrec/metrics/tests/test_segmented_ne.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Dict, Iterable, Union + +import torch +from torch import no_grad +from torchrec.metrics.rec_metric import RecTaskInfo +from torchrec.metrics.segmented_ne import SegmentedNEMetric + + +class SegementedNEValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of AUC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+    """
+
+    @no_grad()
+    def _test_segemented_ne_helper(
+        self,
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        expected_ne: torch.Tensor,
+        grouping_keys: torch.Tensor,
+    ) -> None:
+        num_task = labels.shape[0]
+        batch_size = labels.shape[0]
+        task_list = []
+        inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = {
+            "predictions": {},
+            "labels": {},
+            "weights": {},
+        }
+        if grouping_keys is not None:
+            inputs["required_inputs"] = {"grouping_keys": grouping_keys}
+        for i in range(num_task):
+            task_info = RecTaskInfo(
+                name=f"Task:{i}",
+                label_name="label",
+                prediction_name="prediction",
+                weight_name="weight",
+            )
+            task_list.append(task_info)
+            # pyre-ignore
+            inputs["predictions"][task_info.name] = predictions[i]
+            # pyre-ignore
+            inputs["labels"][task_info.name] = labels[i]
+            # pyre-ignore
+            inputs["weights"][task_info.name] = weights[i]
+
+        ne = SegmentedNEMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=batch_size,
+            tasks=task_list,
+            # pyre-ignore
+            num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1),
+        )
+        # pyre-ignore
+        ne.update(**inputs)
+        actual_ne = ne.compute()
+
+        for task_id, task in enumerate(task_list):
+            for label in [0, 1]:
+                cur_actual_ne = actual_ne[
+                    f"segmented_ne-{task.name}|lifetime_segmented_ne_{label}"
+                ]
+                cur_expected_ne = expected_ne[task_id][label]
+
+                torch.testing.assert_close(
+                    cur_actual_ne,
+                    cur_expected_ne,
+                    atol=1e-4,
+                    rtol=1e-4,
+                    check_dtype=False,
+                    equal_nan=True,
+                    msg=f"Actual: {cur_actual_ne}, Expected: {cur_expected_ne}",
+                )
+
+    def test_grouped_ne(self) -> None:
+        test_data = generate_model_outputs_cases()
+        for inputs in test_data:
+            try:
+                self._test_segemented_ne_helper(**inputs)
+            except AssertionError:
+                print("Assertion error caught with data set ", inputs)
+                raise
+
+
+def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
+    return [
+        # base condition
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0, 1, 0, 1, 1]),
+            "expected_ne": torch.tensor([[3.1615, 1.6004]]),
+        },
+        # one sided, edge case 1s
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([1, 1, 1, 1, 1]),
+            "expected_ne": torch.tensor([[torch.nan, 1.3936]]),
+        },
+        # one sided, edge case 0s
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0, 0, 0, 0, 0]),
+            "expected_ne": torch.tensor([[1.3936, torch.nan]]),
+        },
+        # three segments
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1, 0]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9, 0.4]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75, 0.4]]),
+            "grouping_keys": torch.tensor([0, 1, 0, 1, 2, 2]),
+            "expected_ne": torch.tensor([[3.1615, 1.8311, 0.3814]]),
+        },
+        # two tasks
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1], [1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor(
+                [
+                    [0.2, 0.6, 0.8, 0.4, 0.9],
+                    [0.6, 0.2, 0.4, 0.8, 0.9],
+                ]
+            ),
+            "weights": torch.tensor(
+                [
+                    [0.13, 0.2, 0.5, 0.8, 0.75],
+                    [0.13, 0.2, 0.5, 0.8, 0.75],
+                ]
+            ),
+            "grouping_keys": torch.tensor(
+                [0, 1, 0, 1, 1]
+            ),  # for this case, both tasks have same groupings
+            "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]),
+        },
+    ]