diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py
index 368545a51..688f04583 100644
--- a/torchrec/metrics/auc.py
+++ b/torchrec/metrics/auc.py
@@ -50,77 +50,64 @@ def _compute_auc_helper(


 def compute_auc(
     n_tasks: int,
-    predictions: List[torch.Tensor],
-    labels: List[torch.Tensor],
-    weights: List[torch.Tensor],
+    predictions: torch.Tensor,
+    labels: torch.Tensor,
+    weights: torch.Tensor,
+    apply_bin: bool = False,
 ) -> torch.Tensor:
     """
     Computes AUC (Area Under the Curve) for binary classification.

     Args:
         n_tasks (int): number of tasks.
-        predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
-        labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
-        weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
+        predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
+        labels (torch.Tensor): tensor of size (n_tasks, n_examples).
+        weights (torch.Tensor): tensor of size (n_tasks, n_examples).
     """
-    # concatenate tensors along dim = -1
-    predictions_cat = torch.cat(predictions, dim=-1)
-    labels_cat = torch.cat(labels, dim=-1)
-    weights_cat = torch.cat(weights, dim=-1)
-
     aucs = []
-    for predictions_i, labels_i, weights_i in zip(
-        predictions_cat, labels_cat, weights_cat
-    ):
-        auc = _compute_auc_helper(predictions_i, labels_i, weights_i)
+    for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
+        auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
         aucs.append(auc.view(1))
     return torch.cat(aucs)


 def compute_auc_per_group(
     n_tasks: int,
-    predictions: List[torch.Tensor],
-    labels: List[torch.Tensor],
-    weights: List[torch.Tensor],
-    grouping_keys: List[torch.Tensor],
+    predictions: torch.Tensor,
+    labels: torch.Tensor,
+    weights: torch.Tensor,
+    grouping_keys: torch.Tensor,
 ) -> torch.Tensor:
     """
     Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels.

     Args:
         n_tasks (int): number of tasks
-        predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
-        labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
-        weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
-        grouping_keys (List[torch.Tensor]): List of tensors of size (n_examples,)
+        predictions (torch.Tensor): tensor of size (n_tasks, n_examples)
+        labels (torch.Tensor): tensor of size (n_tasks, n_examples)
+        weights (torch.Tensor): tensor of size (n_tasks, n_examples)
+        grouping_keys (torch.Tensor): tensor of size (n_examples,)

     Returns:
         torch.Tensor: tensor of size (n_tasks,), average of AUCs per group.
     """
-    predictions_cat = torch.cat(predictions, dim=-1)
-    labels_cat = torch.cat(labels, dim=-1)
-    weights_cat = torch.cat(weights, dim=-1)
-    grouping_keys_cat = torch.cat(grouping_keys, dim=-1)
-
     aucs = []
-    if grouping_keys_cat.numel() != 0 and grouping_keys_cat[0] == -1:
+    if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
         # we added padding as the first elements during init to avoid floating point exception in sync()
         # removing the paddings to avoid numerical errors.
-        grouping_keys_cat = grouping_keys_cat[1:]
-        predictions_cat = predictions_cat[:, 1:]
-        labels_cat = labels_cat[:, 1:]
-        weights_cat = weights_cat[:, 1:]
+        grouping_keys = grouping_keys[1:]
+        predictions = predictions[:, 1:]
+        labels = labels[:, 1:]
+        weights = weights[:, 1:]

     # get unique group indices
-    group_indices = torch.unique(grouping_keys_cat)
+    group_indices = torch.unique(grouping_keys)

-    for (predictions_i, labels_i, weights_i) in zip(
-        predictions_cat, labels_cat, weights_cat
-    ):
+    for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights):
         # Loop over each group
         auc_groups_sum = torch.tensor([0], dtype=torch.float32)
         for group_idx in group_indices:
             # get predictions, labels, and weights for this group
-            group_mask = grouping_keys_cat == group_idx
+            group_mask = grouping_keys == group_idx
             grouped_predictions = predictions_i[group_mask]
             grouped_labels = labels_i[group_mask]
             grouped_weights = weights_i[group_mask]
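Note: `compute_auc` and `compute_auc_per_group` now take a single `(n_tasks, n_examples)` tensor per argument instead of a list of per-batch tensors, so the `torch.cat` calls at the top of each function are gone. A minimal sketch of the new calling convention with made-up values (the import path follows the file touched above):

```python
import torch

from torchrec.metrics.auc import compute_auc

# Two tasks, four examples each: rows are tasks, columns are examples.
predictions = torch.tensor([[0.2, 0.6, 0.8, 0.4], [0.1, 0.3, 0.9, 0.7]])
labels = torch.tensor([[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0]])
weights = torch.ones(2, 4)

# Returns a tensor of size (n_tasks,): one AUC per task.
# `apply_bin` defaults to False, matching the new keyword argument.
aucs = compute_auc(2, predictions, labels, weights)
```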
@@ -254,12 +241,25 @@ def update(
         predictions = predictions.float()
         labels = labels.float()
         weights = weights.float()
-
+        num_samples = getattr(self, PREDICTIONS)[0].size(-1)
+        batch_size = predictions.size(-1)
+        start_index = max(num_samples + batch_size - self._window_size, 0)
         # Using `self.predictions =` will cause Pyre errors.
-        getattr(self, PREDICTIONS).append(predictions)
-        getattr(self, LABELS).append(labels)
-        getattr(self, WEIGHTS).append(weights)
-
+        getattr(self, PREDICTIONS)[0] = torch.cat(
+            [
+                cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
+                predictions,
+            ],
+            dim=-1,
+        )
+        getattr(self, LABELS)[0] = torch.cat(
+            [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
+            dim=-1,
+        )
+        getattr(self, WEIGHTS)[0] = torch.cat(
+            [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
+            dim=-1,
+        )
         if self._grouped_auc:
             if REQUIRED_INPUTS not in kwargs or (
                 (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
@@ -267,8 +267,13 @@ def update(
                 raise RecMetricException(
                     f"Input '{GROUPING_KEYS}' are required for AUCMetricComputation grouped update"
                 )
-
-            getattr(self, GROUPING_KEYS).append(grouping_keys.squeeze())
+            getattr(self, GROUPING_KEYS)[0] = torch.cat(
+                [
+                    cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:],
+                    grouping_keys.squeeze(),
+                ],
+                dim=0,
+            )

     def _compute(self) -> List[MetricComputationReport]:
         reports = [
@@ -277,12 +282,10 @@ def _compute(self) -> List[MetricComputationReport]:
                 metric_prefix=MetricPrefix.WINDOW,
                 value=compute_auc(
                     self._n_tasks,
-                    # pyre-ignore[6]
-                    cast(torch.Tensor, getattr(self, PREDICTIONS)),
-                    # pyre-ignore[6]
-                    cast(torch.Tensor, getattr(self, LABELS)),
-                    # pyre-ignore[6]
-                    cast(torch.Tensor, getattr(self, WEIGHTS)),
+                    cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
+                    cast(torch.Tensor, getattr(self, LABELS)[0]),
+                    cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
+                    self._apply_bin,
                 ),
             )
         ]
@@ -293,14 +296,10 @@ def _compute(self) -> List[MetricComputationReport]:
                     metric_prefix=MetricPrefix.WINDOW,
                     value=compute_auc_per_group(
                         self._n_tasks,
-                        # pyre-ignore[6]
-                        cast(torch.Tensor, getattr(self, PREDICTIONS)),
-                        # pyre-ignore[6]
-                        cast(torch.Tensor, getattr(self, LABELS)),
-                        # pyre-ignore[6]
-                        cast(torch.Tensor, getattr(self, WEIGHTS)),
-                        # pyre-ignore[6]
-                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)),
+                        cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
+                        cast(torch.Tensor, getattr(self, LABELS)[0]),
+                        cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
+                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
                     ),
                )
            )
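The rewritten `update` keeps one bounded tensor per state instead of appending one tensor per batch: `start_index` drops the oldest columns so the buffer never exceeds `self._window_size`. A standalone sketch of just that trimming arithmetic (the helper name and the `window_size` value are illustrative, not torchrec API):

```python
import torch


def append_with_window(
    buffer: torch.Tensor, batch: torch.Tensor, window_size: int
) -> torch.Tensor:
    # Same arithmetic as `start_index` in `update`: keep only the newest
    # columns of the existing buffer so that, after concatenation, at most
    # `window_size` examples remain along the last dimension.
    num_samples = buffer.size(-1)
    batch_size = batch.size(-1)
    start_index = max(num_samples + batch_size - window_size, 0)
    return torch.cat([buffer[:, start_index:], batch], dim=-1)


buffer = torch.zeros(1, 0)  # empty (n_tasks=1, n_examples=0) state
for batch in (torch.rand(1, 3), torch.rand(1, 3), torch.rand(1, 3)):
    buffer = append_with_window(buffer, batch, window_size=5)
assert buffer.size(-1) == 5  # capped at window_size examples
```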
diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py
index 8de096dc3..e81922cd4 100644
--- a/torchrec/metrics/tests/test_auc.py
+++ b/torchrec/metrics/tests/test_auc.py
@@ -177,23 +177,6 @@ def test_calc_auc_balanced(self) -> None:
         actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"]
         torch.allclose(expected_auc, actual_auc)

-    def test_calc_multiple_updates(self) -> None:
-        expected_auc = torch.tensor([0.4464], dtype=torch.float)
-        # first batch
-        self.labels["DefaultTask"] = torch.tensor([1, 0, 0])
-        self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8])
-        self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5])
-
-        self.auc.update(**self.batches)
-        # second batch
-        self.labels["DefaultTask"] = torch.tensor([1, 1])
-        self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9])
-        self.weights["DefaultTask"] = torch.tensor([0.8, 0.75])
-
-        self.auc.update(**self.batches)
-        multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"]
-        torch.allclose(expected_auc, multiple_batch)
-

 def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
     return [
diff --git a/torchrec/metrics/tests/test_gpu.py b/torchrec/metrics/tests/test_gpu.py
index 184ea5c12..9ca7f9ad4 100644
--- a/torchrec/metrics/tests/test_gpu.py
+++ b/torchrec/metrics/tests/test_gpu.py
@@ -48,9 +48,9 @@ def test_auc_reset(self) -> None:
             labels={"DefaultTask": model_output["label"]},
             weights={"DefaultTask": model_output["weight"]},
         )
-        self.assertEqual(len(auc._metrics_computations[0].predictions), 2)
-        self.assertEqual(len(auc._metrics_computations[0].labels), 2)
-        self.assertEqual(len(auc._metrics_computations[0].weights), 2)
+        self.assertEqual(len(auc._metrics_computations[0].predictions), 1)
+        self.assertEqual(len(auc._metrics_computations[0].labels), 1)
+        self.assertEqual(len(auc._metrics_computations[0].weights), 1)
         self.assertEqual(auc._metrics_computations[0].predictions[0].device, device)
         self.assertEqual(auc._metrics_computations[0].labels[0].device, device)
         self.assertEqual(auc._metrics_computations[0].weights[0].device, device)
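For the grouped path, `compute_auc_per_group` follows the same convention, plus a 1-D `grouping_keys` tensor aligned with the example dimension. A minimal sketch with made-up values (each group here contains both classes, so every group yields a well-defined AUC):

```python
import torch

from torchrec.metrics.auc import compute_auc_per_group

predictions = torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]])
labels = torch.tensor([[0.0, 1.0, 1.0, 0.0, 1.0]])
weights = torch.ones(1, 5)
grouping_keys = torch.tensor([0, 0, 0, 1, 1])  # examples 0-2 in group 0, 3-4 in group 1

# Returns a tensor of size (n_tasks,): the average of the per-group AUCs.
grouped_auc = compute_auc_per_group(1, predictions, labels, weights, grouping_keys)
```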