Skip to content

Commit

Permalink
Support link prediction NDCG metric (#276)
Browse files Browse the repository at this point in the history
as titled.
  • Loading branch information
XinweiHe authored Nov 14, 2024
1 parent 00fb9d2 commit 500fd04
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
30 changes: 30 additions & 0 deletions relbench/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,33 @@ def link_prediction_map(
precision_mat = np.cumsum(pred_isin, axis=1) / (np.arange(eval_k) + 1)
maps = (precision_mat * pred_isin).sum(axis=1) / clipped_dst_count
return maps.mean()

def link_prediction_ndcg(
pred_isin: NDArray[np.int_],
dst_count: NDArray[np.int_],
) -> float:
pred_isin, dst_count = _filter(pred_isin, dst_count)
eval_k = pred_isin.shape[1]

# Compute the discounted multiplier (1 / log2(i + 2) for i = 0, ..., k-1)
discounted_multiplier = np.concatenate((
np.zeros(1),
1 / np.log2(np.arange(1, eval_k + 1) + 1)
))

# Compute Discounted Cumulative Gain (DCG)
discounted_cumulative_gain = (pred_isin * discounted_multiplier[1:eval_k + 1]).sum(axis=1)

# Clip dst_count to the range [0, eval_k]
clipped_dst_count = np.clip(dst_count, 0, eval_k)

# Compute Ideal Discounted Cumulative Gain (IDCG)
ideal_discounted_multiplier_cumsum = np.cumsum(discounted_multiplier)
ideal_discounted_cumulative_gain = ideal_discounted_multiplier_cumsum[clipped_dst_count]

# Avoid division by zero
ideal_discounted_cumulative_gain = np.clip(ideal_discounted_cumulative_gain, 1e-10, None)

# Compute NDCG
ndcg_scores = discounted_cumulative_gain / ideal_discounted_cumulative_gain
return ndcg_scores.mean()
3 changes: 3 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
link_prediction_map,
link_prediction_precision,
link_prediction_recall,
link_prediction_ndcg,
)


Expand All @@ -15,6 +16,8 @@ def test_link_prediction_metrics():
recall = link_prediction_recall(pred_isin, dst_count)
precision = link_prediction_precision(pred_isin, dst_count)
map = link_prediction_map(pred_isin, dst_count)
ndcg = link_prediction_ndcg(pred_isin, dst_count)
assert 0 <= recall <= 1
assert 0 <= precision <= 1
assert 0 <= map <= 1
assert 0 <= ndcg <= 1

0 comments on commit 500fd04

Please sign in to comment.