# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 17:27:26 2015
@author: Balázs Hidasi
"""
import numpy as np
import pandas as pd
def evaluate_sessions_batch(pr, test_data, items=None, cut_off=20, batch_size=100, break_ties=False, session_key='SessionId', item_key='ItemId', time_key='Time'):
'''
    Evaluates the GRU4Rec network w.r.t. recommendation accuracy, measured by recall@N and MRR@N.
Parameters
--------
pr : gru4rec.GRU4Rec
A trained instance of the GRU4Rec network.
test_data : pandas.DataFrame
        Test data. It contains the transactions of the test set. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
        It must have a header. Column names are arbitrary, but must correspond to the keys you use in this function.
    items : 1D list or None
        The list of item IDs against which the score of the relevant item is compared. If None, all items of the training set are used. Default value is None.
    cut_off : int
        Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Default value is 20.
batch_size : int
Number of events bundled into a batch during evaluation. Speeds up evaluation. If it is set high, the memory consumption increases. Default value is 100.
break_ties : boolean
        Whether to add a small random number to each prediction value in order to break up possible ties, which can mess up the evaluation.
        Defaults to False, because (1) GRU4Rec usually does not produce ties, except when the output saturates; (2) it slows down the evaluation.
        Set to True if you expect lots of ties.
session_key : string
Header of the session ID column in the input file (default: 'SessionId')
item_key : string
Header of the item ID column in the input file (default: 'ItemId')
time_key : string
Header of the timestamp column in the input file (default: 'Time')
Returns
--------
out : tuple
(Recall@N, MRR@N)
'''
    pr.predict = None # Reset the cached predict function, in case the same model is evaluated both with items=None and with a restricted item set; it has to be rebuilt when the item set changes.
test_data.sort_values([session_key, time_key], inplace=True)
offset_sessions = np.zeros(test_data[session_key].nunique()+1, dtype=np.int32)
offset_sessions[1:] = test_data.groupby(session_key).size().cumsum()
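    # offset_sessions[i] is the index of the first event of the i-th session in the
    # (session, time)-sorted test data; consecutive entries delimit one session each.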
    evaluation_point_count = 0
mrr, recall = 0.0, 0.0
if len(offset_sessions) - 1 < batch_size:
batch_size = len(offset_sessions) - 1
iters = np.arange(batch_size).astype(np.int32)
maxiter = iters.max()
start = offset_sessions[iters]
end = offset_sessions[iters+1]
in_idx = np.zeros(batch_size, dtype=np.int32)
np.random.seed(42)
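    # Each of the batch_size slots walks through one session in parallel:
    # iters[j] holds the index of the session currently assigned to slot j,
    # while start[j]/end[j] are the offsets of that session's first and last
    # events. When a slot's session is exhausted, it is refilled with the next
    # unprocessed session, or retired (iters[j] = -1) when none remain.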
while True:
valid_mask = iters >= 0
if valid_mask.sum() == 0:
break
start_valid = start[valid_mask]
minlen = (end[valid_mask]-start_valid).min()
in_idx[valid_mask] = test_data[item_key].values[start_valid]
        for i in range(minlen-1):
out_idx = test_data[item_key].values[start_valid+i+1]
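            # When scoring is restricted to `items`, append any target items missing
            # from that set so that every target still receives a score.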
if items is not None:
uniq_out = np.unique(np.array(out_idx, dtype=np.int32))
preds = pr.predict_next_batch(iters, in_idx, np.hstack([items, uniq_out[~np.in1d(uniq_out,items)]]), batch_size)
else:
preds = pr.predict_next_batch(iters, in_idx, None, batch_size)
if break_ties:
preds += np.random.rand(*preds.values.shape) * 1e-8
preds.fillna(0, inplace=True)
in_idx[valid_mask] = out_idx
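            # Rank of the target = number of items scored strictly higher + 1;
            # with strict comparison, ties are resolved in favour of the target.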
if items is not None:
                others = preds.loc[items].values.T[valid_mask].T
                targets = np.diag(preds.loc[in_idx].values)[valid_mask]
                ranks = (others > targets).sum(axis=0) + 1
            else:
                ranks = (preds.values.T[valid_mask].T > np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0) + 1
            rank_ok = ranks <= cut_off # rank <= N counts as a hit for recall@N and MRR@N
recall += rank_ok.sum()
mrr += (1.0 / ranks[rank_ok]).sum()
            evaluation_point_count += len(ranks)
start = start+minlen-1
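        # Refill slots whose session just ended with the next unseen session;
        # retire a slot (iters[idx] = -1) once all sessions have been assigned.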
mask = np.arange(len(iters))[(valid_mask) & (end-start<=1)]
for idx in mask:
maxiter += 1
if maxiter >= len(offset_sessions)-1:
iters[idx] = -1
else:
iters[idx] = maxiter
start[idx] = offset_sessions[maxiter]
end[idx] = offset_sessions[maxiter+1]
    return recall/evaluation_point_count, mrr/evaluation_point_count
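# Example usage (a minimal, hypothetical sketch: `gru` is assumed to be a trained
# gru4rec.GRU4Rec instance and `test_df` a DataFrame with SessionId/ItemId/Time columns):
#
#     recall, mrr = evaluate_sessions_batch(gru, test_df, cut_off=20)
#     print('Recall@20: {:.4f}, MRR@20: {:.4f}'.format(recall, mrr))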
def evaluate_sessions(pr, test_data, train_data, items=None, cut_off=20, session_key='SessionId', item_key='ItemId', time_key='Time'):
'''
    Evaluates a baseline predictor w.r.t. recommendation accuracy, measured by recall@N and MRR@N. Has no batch evaluation capability. Breaks up ties.
Parameters
--------
pr : baseline predictor
A trained instance of a baseline predictor.
test_data : pandas.DataFrame
        Test data. It contains the transactions of the test set. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
        It must have a header. Column names are arbitrary, but must correspond to the keys you use in this function.
train_data : pandas.DataFrame
Training data. Only required for selecting the set of item IDs of the training set.
items : 1D list or None
        The list of item IDs against which the score of the relevant item is compared. If None, all items of the training set are used. Default value is None.
    cut_off : int
        Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Default value is 20.
session_key : string
Header of the session ID column in the input file (default: 'SessionId')
item_key : string
Header of the item ID column in the input file (default: 'ItemId')
time_key : string
Header of the timestamp column in the input file (default: 'Time')
Returns
--------
out : tuple
(Recall@N, MRR@N)
'''
test_data.sort_values([session_key, time_key], inplace=True)
items_to_predict = train_data[item_key].unique()
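    # By default the relevant item is ranked against every item seen during training;
    # when an explicit item set is given, this is narrowed per event inside the loop.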
    evaluation_point_count = 0
prev_iid, prev_sid = -1, -1
mrr, recall = 0.0, 0.0
    for i in range(len(test_data)):
sid = test_data[session_key].values[i]
iid = test_data[item_key].values[i]
if prev_sid != sid:
prev_sid = sid
else:
if items is not None:
if np.in1d(iid, items): items_to_predict = items
else: items_to_predict = np.hstack(([iid], items))
preds = pr.predict_next(sid, prev_iid, items_to_predict)
preds[np.isnan(preds)] = 0
preds += 1e-8 * np.random.rand(len(preds)) #Breaking up ties
            rank = (preds > preds[iid]).sum() + 1
            assert rank > 0
            if rank <= cut_off: # rank <= N counts as a hit for recall@N and MRR@N
                recall += 1
                mrr += 1.0/rank
            evaluation_point_count += 1
prev_iid = iid
    return recall/evaluation_point_count, mrr/evaluation_point_count
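
if __name__ == '__main__':
    # Minimal smoke-test sketch. The file paths and GRU4Rec hyperparameters below
    # are hypothetical placeholders, not part of the original script; adjust them
    # to your own data and setup before running.
    import gru4rec
    train = pd.read_csv('train.tsv', sep='\t')  # hypothetical path
    test = pd.read_csv('test.tsv', sep='\t')    # hypothetical path
    gru = gru4rec.GRU4Rec(layers=[100], loss='top1', batch_size=50, n_epochs=1)  # assumed hyperparameters
    gru.fit(train)
    recall, mrr = evaluate_sessions_batch(gru, test, cut_off=20)
    print('Recall@20: {:.4f}\tMRR@20: {:.4f}'.format(recall, mrr))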