More efficient way to compute weighted IDCG in LinkPredNDCG (#9946)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
rusty1s and pre-commit-ci[bot] authored Jan 15, 2025
1 parent 371678c commit 8db0702
Showing 2 changed files with 31 additions and 26 deletions.
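For context: NDCG@k scores each source node by the discounted cumulative gain (DCG) of its top-k predictions, normalized by the ideal DCG (IDCG) obtained when the ground-truth labels are ranked by descending weight. A minimal per-example sketch of the weighted variant (ndcg_at_k is a hypothetical helper, not part of the library API; it assumes the 1/log2(rank + 2) discount that LinkPredNDCG uses):

import torch

def ndcg_at_k(pred: torch.Tensor, rel: dict, k: int) -> float:
    # Discount log2(rank + 2) for ranks 0..k-1:
    discount = torch.arange(2, k + 2, dtype=torch.float).log2()
    # DCG: gains of the predicted items, in predicted order:
    gains = torch.tensor([rel.get(int(i), 0.0) for i in pred[:k]])
    dcg = (gains / discount[:gains.numel()]).sum()
    # IDCG: the top-k ground-truth weights in descending order:
    ideal = torch.tensor(sorted(rel.values(), reverse=True)[:k])
    idcg = (ideal / discount[:ideal.numel()]).sum()
    return float(dcg / idcg) if idcg > 0 else 0.0

# Example 0 from the test below: predicting [1, 0] against
# ground-truth weights {0: 1.0, 1: 2.0, 2: 0.1} is a perfect ranking:
print(ndcg_at_k(torch.tensor([1, 0]), {0: 1.0, 1: 2.0, 2: 0.1}, k=2))  # 1.0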
12 changes: 8 additions & 4 deletions test/metrics/test_link_pred_metric.py
@@ -112,14 +112,13 @@ def test_map():
 
 def test_ndcg():
     pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
-    edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]])
-    edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5])
+    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])
+    edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5])
 
     metric = LinkPredNDCG(k=2)
     assert str(metric) == 'LinkPredNDCG(k=2)'
     metric.update(pred_index_mat, edge_label_index)
     result = metric.compute()
-
     assert float(result) == pytest.approx(0.6934264)
 
     # Test with `k > pred_index_mat.size(1)`:
@@ -134,9 +133,14 @@ def test_ndcg():
 
     metric.update(pred_index_mat, edge_label_index, edge_label_weight)
     result = metric.compute()
-
+    metric.reset()
     assert float(result) == pytest.approx(0.7854486)
 
+    perm = torch.randperm(edge_label_weight.size(0))
+    metric.update(pred_index_mat, edge_label_index[:, perm],
+                  edge_label_weight[perm])
+    assert metric.compute() == result
+
     # Test with `k > pred_index_mat.size(1)`:
     metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)
     metric.compute()
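The new assertion checks an invariance property: since the IDCG depends only on the multiset of (example, weight) pairs, shuffling the ground-truth edges together with their weights must not change the result. A usage sketch (values taken from the test above; assumes a PyG build that contains this commit):

import torch
from torch_geometric.metrics import LinkPredNDCG

metric = LinkPredNDCG(k=2)
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])
edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5])

# Any permutation of the ground-truth edges (with weights shuffled in
# lockstep) yields the same weighted NDCG:
perm = torch.randperm(edge_label_weight.size(0))
metric.update(pred_index_mat, edge_label_index[:, perm],
              edge_label_weight[perm])
print(float(metric.compute()))  # ~0.7854486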
45 changes: 23 additions & 22 deletions torch_geometric/metrics/link_pred.py
@@ -4,7 +4,6 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.index import index2ptr
 from torch_geometric.utils import cumsum, scatter
 
 try:
@@ -425,27 +424,29 @@ def _compute(self, data: LinkPredMetricData) -> Tensor:
             idcg = self.idcg[data.label_count.clamp(max=self.k)]
         else:
             assert data.edge_label_weight is not None
-            # Sort weights in buckets via two sorts:
-            weight, perm = data.edge_label_weight.sort(descending=True)
-            batch = data.edge_label_index[0][perm]
-            batch, perm = torch.sort(batch, stable=True)
-            weight = weight[perm]
-
-            # Shrink buckets that are larger than `k`:
-            arange = torch.arange(batch.size(0), device=batch.device)
-            ptr = index2ptr(batch, size=data.pred_index_mat.size(0))
-            batched_arange = arange - ptr[batch]
-            mask = batched_arange < self.k
-            batch = batch[mask]
-            batched_arange = batched_arange[mask]
-            weight = weight[mask]
-
-            # Compute ideal relevance matrix:
-            irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k)
-            irel_mat[batch * self.k + batched_arange] = weight
-            irel_mat = irel_mat.view(-1, self.k)
-
-            idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1)
+            # Sort weights within example-wise buckets via two sorts to get the
+            # local index order within buckets:
+            weight, batch = data.edge_label_weight, data.edge_label_index[0]
+            perm1 = weight.argsort(descending=True)
+            perm2 = batch[perm1].argsort(stable=True)
+            global_index = torch.empty_like(perm1)
+            global_index[perm1[perm2]] = torch.arange(
+                global_index.size(0), device=global_index.device)
+            local_index = global_index - cumsum(data.label_count)[batch]
+
+            # Get the discount per local index:
+            discount = torch.cat([
+                self.discount,
+                self.discount.new_full((1, ), fill_value=float('inf')),
+            ])
+            discount = discount[local_index.clamp(max=self.k + 1)]
+
+            idcg = scatter(  # Apply discount and aggregate:
+                weight / discount,
+                batch,
+                dim_size=data.pred_index_mat.size(0),
+                reduce='sum',
+            )
 
         out = dcg / idcg
         out[out.isnan() | out.isinf()] = 0.0
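The replaced code materialized a dense (num_examples * k) ideal-relevance matrix; the new version instead derives each edge's rank inside its example's bucket from two argsorts and aggregates with a single scatter-sum, so memory scales with the number of ground-truth edges. A standalone sketch of that trick on the test data (plain PyTorch only: the zero-prepended prefix sum stands in for torch_geometric.utils.cumsum, scatter_add_ stands in for torch_geometric.utils.scatter, and the clamp targets the appended infinite-discount slot so ranks beyond the top-k contribute nothing):

import torch

weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5])
batch = torch.tensor([0, 0, 0, 2, 2])  # source example per edge
num_examples, k = 3, 2

# Two sorts: globally by descending weight, then stably by example id.
# The composition lists edges bucket by bucket, best weight first:
perm1 = weight.argsort(descending=True)
perm2 = batch[perm1].argsort(stable=True)
order = perm1[perm2]

# Invert the permutation to get each edge's global sorted position:
global_index = torch.empty_like(order)
global_index[order] = torch.arange(order.numel())

# Subtract each bucket's start offset to get per-bucket ranks:
label_count = torch.bincount(batch, minlength=num_examples)
ptr = torch.cat([label_count.new_zeros(1), label_count.cumsum(0)])
local_index = global_index - ptr[batch]  # -> [1, 0, 2, 0, 1]

# Discount per rank, infinite past the top-k so extra edges vanish:
discount = torch.arange(2, k + 2, dtype=weight.dtype).log2()
discount = torch.cat([discount, discount.new_full((1,), float('inf'))])
contrib = weight / discount[local_index.clamp(max=k)]

# One scatter-sum yields the per-example IDCG:
idcg = torch.zeros(num_examples).scatter_add_(0, batch, contrib)
print(idcg)  # tensor([2.6309, 0.0000, 3.3155])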
