Add an AUPRC metric to torchrec metrics (pytorch#1582)
Summary:
Pull Request resolved: pytorch#1582

[APS][RecMetric] Add an AUPRC metric to torchrec metrics

Reviewed By: zainhuda

Differential Revision: D48798090

Privacy Context Container: L1150871

fbshipit-source-id: ba8b7660b566686efebd1fa23eb25ee0f8b8c27f
Mohammad Samavatian authored and facebook-github-bot committed Dec 14, 2023
1 parent 6b03397 commit decc6dd
Showing 6 changed files with 714 additions and 1 deletion.
360 changes: 360 additions & 0 deletions torchrec/metrics/auprc.py
@@ -0,0 +1,360 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Any, cast, Dict, List, Optional, Type

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
    MetricComputationReport,
    RecMetric,
    RecMetricComputation,
    RecMetricException,
)

PREDICTIONS = "predictions"
LABELS = "labels"
WEIGHTS = "weights"
GROUPING_KEYS = "grouping_keys"
REQUIRED_INPUTS = "required_inputs"


def _riemann_integral(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Riemann integral approximates the area of each cell with a rectangle positioned at the edge.
    It is conventionally used for AUPRC rather than the trapezoidal rule, which averages the two
    endpoints of each cell."""
    return -torch.sum((x[1:] - x[:-1]) * y[:-1])
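
# Worked example (hypothetical values): _riemann_integral(torch.tensor([1.0, 0.5, 0.0]),
# torch.tensor([0.5, 1.0, 1.0])) computes -((-0.5 * 0.5) + (-0.5 * 1.0)) = 0.75.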


def _compute_auprc_helper(
    predictions: torch.Tensor,
    labels: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    sorted_indices = torch.argsort(predictions, descending=True, dim=-1)

    threshold = torch.index_select(predictions, dim=0, index=sorted_indices)

    sorted_labels = torch.index_select(labels, dim=0, index=sorted_indices)

    sorted_weights = torch.index_select(weights, dim=0, index=sorted_indices)

    mask = F.pad(threshold.diff(dim=0) != 0, [0, 1], value=1.0)
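    # `mask` marks the last element of each run of equal thresholds (the pad keeps
    # the final element), so the cumulative sums below are sampled once per
    # distinct threshold.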
    num_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)[mask]
    num_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)[mask]

    precision = (num_tp / (num_tp + num_fp)).flip(0)
    recall = (num_tp / num_tp[-1]).flip(0)

    # The last precision and recall values are 1.0 and 0.0 without a corresponding threshold.
    # This ensures that the graph starts on the y-axis.
    precision = torch.cat([precision, precision.new_ones(1)])
    recall = torch.cat([recall, recall.new_zeros(1)])

    # If recalls are NaNs, set NaNs to 1.0s.
    if torch.isnan(recall[0]):
        recall = torch.nan_to_num(recall, 1.0)

    auprc = _riemann_integral(recall, precision)
    return auprc


def compute_auprc(
    n_tasks: int,
    predictions: torch.Tensor,
    labels: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    """
    Computes AUPRC (Area Under the Precision-Recall Curve) for binary classification.

    Args:
        n_tasks (int): number of tasks.
        predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
        labels (torch.Tensor): tensor of size (n_tasks, n_examples).
        weights (torch.Tensor): tensor of size (n_tasks, n_examples).
    """
    auprcs = []
    for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
        auprc = _compute_auprc_helper(predictions_i, labels_i, weights_i)
        auprcs.append(auprc.view(1))
    return torch.cat(auprcs)
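
# Direct-call sketch (hypothetical values, single task); a perfect ranking yields 1.0:
#   compute_auprc(
#       1,
#       torch.tensor([[0.9, 0.2, 0.7]]),
#       torch.tensor([[1.0, 0.0, 1.0]]),
#       torch.ones(1, 3),
#   )  # -> tensor([1.])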


def compute_auprc_per_group(
    n_tasks: int,
    predictions: torch.Tensor,
    labels: torch.Tensor,
    weights: torch.Tensor,
    grouping_keys: torch.Tensor,
) -> torch.Tensor:
    """
    Computes AUPRC (Area Under the Precision-Recall Curve) for binary classification for groups of predictions/labels.

    Args:
        n_tasks (int): number of tasks
        predictions (torch.Tensor): tensor of size (n_tasks, n_examples)
        labels (torch.Tensor): tensor of size (n_tasks, n_examples)
        weights (torch.Tensor): tensor of size (n_tasks, n_examples)
        grouping_keys (torch.Tensor): tensor of size (n_examples,)

    Returns:
        torch.Tensor: tensor of size (n_tasks,), average of AUPRCs per group.
    """
    auprcs = []
    if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
        # A -1 padding element was added as the first entry during init to avoid a
        # floating point exception in sync(); strip the padding here to avoid
        # numerical errors.
        grouping_keys = grouping_keys[1:]
        predictions = predictions[:, 1:]
        labels = labels[:, 1:]
        weights = weights[:, 1:]

    # get unique group indices
    group_indices = torch.unique(grouping_keys)

    for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
        # Loop over each group
        auprc_groups_sum = torch.tensor([0], dtype=torch.float32)
        for group_idx in group_indices:
            # get predictions, labels, and weights for this group
            group_mask = grouping_keys == group_idx
            grouped_predictions = predictions_i[group_mask]
            grouped_labels = labels_i[group_mask]
            grouped_weights = weights_i[group_mask]

            auprc = _compute_auprc_helper(
                grouped_predictions, grouped_labels, grouped_weights
            )
            auprc_groups_sum = auprc_groups_sum.to(auprc.device)
            auprc_groups_sum += auprc.view(1)
        avg_auprc = (
            auprc_groups_sum / len(group_indices)
            if len(group_indices) > 0
            else torch.tensor([0.5], dtype=torch.float32)
        )
        auprcs.append(avg_auprc)
    return torch.cat(auprcs)
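
# Grouped sketch (hypothetical values): with grouping_keys = torch.tensor([0, 0, 1, 1])
# and (1, 4)-shaped predictions/labels/weights, AUPRC is computed separately over the
# first two and last two examples, then the two values are averaged.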


def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
    return [torch.cat(state, dim=dim)]


# pyre-ignore
_grouping_keys_state_reduction = partial(_state_reduction, dim=0)
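# Grouping keys form a 1-D tensor over examples, so they are concatenated along
# dim=0; the (n_tasks, n_examples) prediction/label/weight states use dim=1.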


class AUPRCMetricComputation(RecMetricComputation):
    r"""
    This class implements the RecMetricComputation for AUPRC, i.e. the area under the
    precision-recall curve.
    The constructor arguments are defined in RecMetricComputation.
    See the docstring of RecMetricComputation for more detail.

    Args:
        grouped_auprc (bool): If True, computes AUPRC per group and returns the average AUPRC across all groups.
            The `grouping_keys` tensor is provided during state updates along with predictions, labels, and weights.
            This feature is currently not enabled for `fused_update_limit`.
    """

    def __init__(
        self,
        *args: Any,
        grouped_auprc: bool = False,
        fused_update_limit: int = 0,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        if grouped_auprc and fused_update_limit > 0:
            raise RecMetricException(
                "Grouped AUPRC and Fused Update Limit cannot be enabled together yet."
            )

        self._grouped_auprc: bool = grouped_auprc
        self._add_state(
            PREDICTIONS,
            [],
            add_window_state=False,
            dist_reduce_fx=_state_reduction,
            persistent=False,
        )
        self._add_state(
            LABELS,
            [],
            add_window_state=False,
            dist_reduce_fx=_state_reduction,
            persistent=False,
        )
        self._add_state(
            WEIGHTS,
            [],
            add_window_state=False,
            dist_reduce_fx=_state_reduction,
            persistent=False,
        )
        if self._grouped_auprc:
            self._add_state(
                GROUPING_KEYS,
                [],
                add_window_state=False,
                dist_reduce_fx=_grouping_keys_state_reduction,
                persistent=False,
            )
        self._init_states()

    # The state values are set to empty lists in __init__() and reset(), and then we
    # add a tensor of size (self._n_tasks, 1) to each list as the initial value.
    # This is to bypass the limitation of state aggregation in TorchMetrics sync() when
    # we try to checkpoint the states before update().
    # Lists are used to avoid automatically stacking the tensors from all the trainers
    # into one tensor in sync().
    # Non-empty tensors are used as the first elements to avoid the floating point
    # exception thrown in sync() when aggregating empty tensors.
    def _init_states(self) -> None:
        if len(getattr(self, PREDICTIONS)) > 0:
            return

        getattr(self, PREDICTIONS).append(
            torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
        )
        getattr(self, LABELS).append(
            torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
        )
        getattr(self, WEIGHTS).append(
            torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
        )
        if self._grouped_auprc:
            getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

    def update(
        self,
        *,
        predictions: Optional[torch.Tensor],
        labels: torch.Tensor,
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        Args:
            predictions (torch.Tensor): tensor of size (n_task, n_examples)
            labels (torch.Tensor): tensor of size (n_task, n_examples)
            weights (torch.Tensor): tensor of size (n_task, n_examples)
            grouping_keys (torch.Tensor): Optional tensor of size (1, n_examples) that specifies the groups of
                    predictions/labels per batch. If provided, the PR AUC metric also
                    computes PR AUC per group and returns the average PR AUC across all groups.
        """
        if predictions is None or weights is None:
            raise RecMetricException(
                "Inputs 'predictions' and 'weights' should not be None for AUPRCMetricComputation update"
            )
        predictions = predictions.float()
        labels = labels.float()
        weights = weights.float()
        num_samples = getattr(self, PREDICTIONS)[0].size(-1)
        batch_size = predictions.size(-1)
        start_index = max(num_samples + batch_size - self._window_size, 0)
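        # Sliding window: keep at most self._window_size examples by dropping the
        # oldest stored entries before appending the new batch.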
        # Using `self.predictions =` will cause Pyre errors.
        getattr(self, PREDICTIONS)[0] = torch.cat(
            [
                cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
                predictions,
            ],
            dim=-1,
        )
        getattr(self, LABELS)[0] = torch.cat(
            [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
            dim=-1,
        )
        getattr(self, WEIGHTS)[0] = torch.cat(
            [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
            dim=-1,
        )
        if self._grouped_auprc:
            if REQUIRED_INPUTS not in kwargs or (
                (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
            ):
                raise RecMetricException(
                    f"Input '{GROUPING_KEYS}' is required for AUPRCMetricComputation grouped update"
                )
            getattr(self, GROUPING_KEYS)[0] = torch.cat(
                [
                    cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:],
                    grouping_keys.squeeze(),
                ],
                dim=0,
            )

    def _compute(self) -> List[MetricComputationReport]:
        reports = [
            MetricComputationReport(
                name=MetricName.AUPRC,
                metric_prefix=MetricPrefix.WINDOW,
                value=compute_auprc(
                    self._n_tasks,
                    cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
                    cast(torch.Tensor, getattr(self, LABELS)[0]),
                    cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
                ),
            )
        ]
        if self._grouped_auprc:
            reports.append(
                MetricComputationReport(
                    name=MetricName.GROUPED_AUPRC,
                    metric_prefix=MetricPrefix.WINDOW,
                    value=compute_auprc_per_group(
                        self._n_tasks,
                        cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
                        cast(torch.Tensor, getattr(self, LABELS)[0]),
                        cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
                    ),
                )
            )
        return reports

    def reset(self) -> None:
        super().reset()
        self._init_states()


class AUPRCMetric(RecMetric):
    _namespace: MetricNamespace = MetricNamespace.AUPRC
    _computation_class: Type[RecMetricComputation] = AUPRCMetricComputation

    def __init__(
        self,
        world_size: int,
        my_rank: int,
        batch_size: int,
        tasks: List[RecTaskInfo],
        compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
        window_size: int = 100,
        fused_update_limit: int = 0,
        compute_on_all_ranks: bool = False,
        should_validate_update: bool = False,
        process_group: Optional[dist.ProcessGroup] = None,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(
            world_size=world_size,
            my_rank=my_rank,
            batch_size=batch_size,
            tasks=tasks,
            compute_mode=compute_mode,
            window_size=window_size,
            fused_update_limit=fused_update_limit,
            compute_on_all_ranks=compute_on_all_ranks,
            should_validate_update=should_validate_update,
            process_group=process_group,
            **kwargs,
        )
        if kwargs.get("grouped_auprc"):
            self._required_inputs.add(GROUPING_KEYS)
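
For reviewers who want to exercise the new metric end to end, here is a minimal single-rank sketch. The values are hypothetical; DefaultTaskInfo and the task-name-keyed update convention are assumed to match the existing AUC metric, and the exact output key format may differ:

import torch
from torchrec.metrics.auprc import AUPRCMetric
from torchrec.metrics.metrics_config import DefaultTaskInfo

metric = AUPRCMetric(
    world_size=1,
    my_rank=0,
    batch_size=64,
    tasks=[DefaultTaskInfo],
)
metric.update(
    predictions={"DefaultTask": torch.rand(64)},
    labels={"DefaultTask": torch.randint(0, 2, (64,)).float()},
    weights={"DefaultTask": torch.ones(64)},
)
print(metric.compute())  # assumed key format: "auprc-DefaultTask|window_auprc"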
2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
@@ -18,6 +18,7 @@
from torch.profiler import record_function
from torchrec.metrics.accuracy import AccuracyMetric
from torchrec.metrics.auc import AUCMetric
from torchrec.metrics.auprc import AUPRCMetric
from torchrec.metrics.calibration import CalibrationMetric
from torchrec.metrics.ctr import CTRMetric
from torchrec.metrics.mae import MAEMetric
@@ -54,6 +55,7 @@
    RecMetricEnum.CTR: CTRMetric,
    RecMetricEnum.CALIBRATION: CalibrationMetric,
    RecMetricEnum.AUC: AUCMetric,
    RecMetricEnum.AUPRC: AUPRCMetric,
    RecMetricEnum.MSE: MSEMetric,
    RecMetricEnum.MAE: MAEMetric,
    RecMetricEnum.MULTICLASS_RECALL: MulticlassRecallMetric,
1 change: 1 addition & 0 deletions torchrec/metrics/metrics_config.py
@@ -22,6 +22,7 @@ class RecMetricEnum(RecMetricEnumBase):
    LOG_LOSS = "log_loss"
    CTR = "ctr"
    AUC = "auc"
    AUPRC = "auprc"
    CALIBRATION = "calibration"
    MSE = "mse"
    MAE = "mae"