Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce weighted LinkPredMetric interface #9943

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LinkPredMetric(BaseMetric):
is_differentiable: bool = False
full_state_update: bool = False
higher_is_better: Optional[bool] = None
weighted: bool = False

def __init__(self, k: int) -> None:
super().__init__()
Expand Down Expand Up @@ -93,6 +94,7 @@ def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
r"""Updates the state variables based on the current mini-batch
prediction.
Expand All @@ -108,7 +110,14 @@ def update(
edge_label_index (torch.Tensor): The ground-truth indices for every
example in the mini-batch, given in COO format of shape
:obj:`[2, num_ground_truth_indices]`.
edge_label_weight (torch.Tensor, optional): The weight of the
ground-truth indices for every example in the mini-batch of
shape :obj:`[num_ground_truth_indices]`. Required for
weighted metrics and ignored otherwise. (default: :obj:`None`)
"""
if self.weighted and edge_label_weight is None:
raise ValueError("'edge_label_weight' required for {self}")

pred_isin_mat, y_count = self._prepare(pred_index_mat,
edge_label_index)
self._update_from_prepared(pred_isin_mat, y_count)
Expand Down Expand Up @@ -206,6 +215,7 @@ def update( # type: ignore
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
r"""Updates the state variables based on the current mini-batch
prediction.
Expand All @@ -221,6 +231,10 @@ def update( # type: ignore
edge_label_index (torch.Tensor): The ground-truth indices for every
example in the mini-batch, given in COO format of shape
:obj:`[2, num_ground_truth_indices]`.
edge_label_weight (torch.Tensor, optional): The weight of the
ground-truth indices for every example in the mini-batch of
shape :obj:`[num_ground_truth_indices]`. Required for
weighted metrics and ignored otherwise. (default: :obj:`None`)
"""
pred_isin_mat, y_count = LinkPredMetric._prepare(
pred_index_mat, edge_label_index)
Expand Down Expand Up @@ -248,6 +262,7 @@ class LinkPredPrecision(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
return pred_isin_mat.sum(dim=-1) / self.k
Expand All @@ -260,6 +275,7 @@ class LinkPredRecall(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7)
Expand All @@ -272,6 +288,7 @@ class LinkPredF1(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
isin_count = pred_isin_mat.sum(dim=-1)
Expand All @@ -288,6 +305,7 @@ class LinkPredMAP(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
device = pred_isin_mat.device
Expand All @@ -305,6 +323,7 @@ class LinkPredNDCG(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def __init__(self, k: int):
super().__init__(k=k)
Expand Down Expand Up @@ -336,6 +355,7 @@ class LinkPredMRR(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
device = pred_isin_mat.device
Expand Down
Loading