
Commit

Add hyperparameter explanation. Fix some minor errors. Add new models into test code
ken77921 committed Nov 24, 2023
1 parent 4b72c3a commit a4e5795
Showing 5 changed files with 86 additions and 102 deletions.
63 changes: 13 additions & 50 deletions recbole/model/sequential_recommender/gru4reccpr.py
@@ -8,6 +8,11 @@
# @Author : Yupeng Hou, Yujie Lu
# @Email : [email protected], [email protected]

# UPDATE:
# @Time : 2023/11/24
# @Author : Haw-Shiuan Chang
# @Email : [email protected]

r"""
GRU4Rec + Softmax-CPR
################################################
@@ -27,6 +32,7 @@
import math
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss
import sys

def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
@@ -64,26 +70,13 @@ def __init__(self, config, dataset):
self.n_facet_window = - self.n_facet_window
self.n_facet_MLP = - self.n_facet_MLP
self.softmax_nonlinear='None' #added for mfs
self.use_att = config['use_att'] #added for mfs
self.use_out_emb = config['use_out_emb'] #added for mfs
self.only_compute_loss = True #added for mfs
if self.use_att:
assert self.use_out_emb
self.dropout = nn.Dropout(self.dropout_prob)
self.We = nn.Linear(self.hidden_size, self.hidden_size)
self.Ue = nn.Linear(self.hidden_size, self.hidden_size)
self.tanh = nn.Tanh()
self.Ve = nn.Linear(self.hidden_size, 1)
out_size = 2*self.hidden_size
else:
self.dense = nn.Linear(self.hidden_size, self.embedding_size)
out_size = self.embedding_size

self.dense = nn.Linear(self.hidden_size, self.embedding_size)
out_size = self.embedding_size

self.n_embd = out_size
#if self.use_att:
# self.n_embd = 2* self.hidden_size #added for mfs
#else:
# self.n_embd = self.hidden_size #added for mfs

self.use_proj_bias = config['use_proj_bias'] #added for mfs
self.weight_mode = config['weight_mode'] #added for mfs
@@ -103,12 +96,8 @@ def __init__(self, config, dataset):
assert self.n_facet_emb == 0 or self.n_facet_emb == 2


if self.n_facet_MLP > 0:
hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1
#print(self.n_embd, self.n_facet_hidden, self.n_facet_window, self.n_facet_MLP)
self.MLP_linear = nn.Linear(self.n_embd * (self.n_facet_hidden * (self.n_facet_window+1) ), self.n_embd * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim)
else:
hidden_state_input_ratio = self.n_facet_hidden * (self.n_facet_window+1) #1 * (0+1)
hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1
self.MLP_linear = nn.Linear(self.n_embd * (self.n_facet_hidden * (self.n_facet_window+1) ), self.n_embd * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim)
total_lin_dim = self.n_embd * hidden_state_input_ratio
self.project_arr = nn.ModuleList([nn.Linear(total_lin_dim, self.n_embd, bias=self.use_proj_bias) for i in range(self.n_facet_all)])
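# Reading aid (a minimal sketch, not authoritative): the shape bookkeeping above,
# assuming the GRU4RecCPR.yaml defaults embedding_size=64, n_facet_hidden=1,
# n_facet_window=-2, n_facet_MLP=-1, after the sign flips applied earlier in __init__.
#   n_embd = 64
#   mlp_in  = n_embd * (n_facet_hidden * (n_facet_window + 1))  -> 64 * (1 * 3) = 192
#   mlp_out = n_embd * n_facet_MLP                               -> 64
#   hidden_state_input_ratio = 1 + n_facet_MLP                   -> 2 (current state + MLP summary)
#   total_lin_dim = n_embd * hidden_state_input_ratio            -> 128 (input of each projection in project_arr)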

@@ -159,33 +148,7 @@ def forward(self, item_seq, item_seq_len):
item_seq_emb = self.item_embedding(item_seq)
item_seq_emb_dropout = self.emb_dropout(item_seq_emb)
gru_output, _ = self.gru_layers(item_seq_emb_dropout)
if self.use_att:
all_memory = gru_output
all_memory_values = all_memory
bsz, seq_len, hsz = all_memory.size()

all_memory_U = self.dropout(self.Ue(all_memory))
all_memory_U = all_memory_U.unsqueeze(2)
all_memory_U = all_memory_U.expand(bsz, seq_len, all_memory.size(1), hsz)

all_memory_W = self.dropout(self.We(all_memory))
all_memory_W = all_memory_U.unsqueeze(1)
all_memory_W = all_memory_U.expand(bsz, all_memory.size(1), seq_len, hsz)

output_ee = self.tanh(all_memory_U + all_memory_W)
output_ee = self.Ve(output_ee).squeeze(-1)

timeline_mask = (item_seq == 0)
output_ee.masked_fill_(timeline_mask.unsqueeze(2).expand(bsz, seq_len, all_memory.size(1)), -1e9)

output_ee = output_ee.unsqueeze(-1)

alpha_e = nn.Softmax(dim=1)(output_ee)
alpha_e = alpha_e.expand(bsz, seq_len, seq_len, self.hidden_size)
output_e = (alpha_e * all_memory_values.unsqueeze(2).expand(bsz, seq_len, all_memory.size(1), hsz)).sum(dim=1)
gru_output = torch.cat([output_e, all_memory_values], dim=-1)
else:
gru_output = self.dense(gru_output)
gru_output = self.dense(gru_output)
return gru_output

def get_facet_emb(self,input_emb, i):
@@ -287,7 +250,7 @@ def calculate_loss_prob(self, interaction, only_compute_prob=False):
bsz, seq_len_2 = item_seq.size()
logit_hidden_context = (projected_emb_arr[self.n_facet + self.n_facet_reranker*len(self.reranker_CAN_NUM) + i].unsqueeze(dim=2).expand(-1,-1,seq_len_2,-1) * test_item_emb[item_seq, :].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1)
if test_item_bias is not None:
logit_hidden_reranker_topn += test_item_bias[item_seq].unsqueeze(dim=1).expand(-1,seq_len_1,-1)
logit_hidden_context += test_item_bias[item_seq].unsqueeze(dim=1).expand(-1,seq_len_1,-1)
logit_hidden_pointer = 0
if self.n_facet_emb == 2:
logit_hidden_pointer = ( projected_emb_arr[-2][:,-1,:].unsqueeze(dim=1).unsqueeze(dim=1).expand(-1,seq_len_1,seq_len_2,-1) * projected_emb_arr[-1].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1)
12 changes: 7 additions & 5 deletions recbole/model/sequential_recommender/sasreccpr.py
@@ -3,6 +3,11 @@
# @Author : Hui Wang
# @Email : [email protected]

# UPDATE:
# @Time : 2023/11/24
# @Author : Haw-Shiuan Chang
# @Email : [email protected]

"""
SASRec + Softmax-CPR
################################################
@@ -81,11 +86,8 @@ def __init__(self, config, dataset):
self.n_facet_MLP = - self.n_facet_MLP
self.softmax_nonlinear='None' #added for mfs
self.use_proj_bias = config['use_proj_bias'] #added for mfs
if self.n_facet_MLP > 0:
hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1
self.MLP_linear = nn.Linear(self.hidden_size * (self.n_facet_hidden * (self.n_facet_window+1) ), self.hidden_size * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim)
else:
hidden_state_input_ratio = self.n_facet_hidden * (self.n_facet_window+1) #1 * (0+1)
hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1
self.MLP_linear = nn.Linear(self.hidden_size * (self.n_facet_hidden * (self.n_facet_window+1) ), self.hidden_size * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim)
total_lin_dim = self.hidden_size * hidden_state_input_ratio
self.project_arr = nn.ModuleList([nn.Linear(total_lin_dim, self.hidden_size, bias=self.use_proj_bias) for i in range(self.n_facet_all)])

49 changes: 27 additions & 22 deletions recbole/properties/model/GRU4RecCPR.yaml
@@ -1,23 +1,28 @@
embedding_size: 64 # (int) The embedding size of items.
hidden_size: 128 # (int) The number of features in the hidden state.
num_layers: 1 # (int) The number of layers in GRU.
dropout_prob: 0 # (float) The dropout rate.
loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'.
embedding_size: 64 # (int) The embedding size of items.
hidden_size: 128 # (int) The number of features in the hidden state.
num_layers: 1 # (int) The number of layers in GRU.
dropout_prob: 0 # (float) The dropout rate.
loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'.

use_att: False
use_out_emb: False
n_facet: 1
n_facet_all: 5
n_facet_hidden: 1
n_facet_window: -2
n_facet_MLP: -1
n_facet_context: 1
n_facet_reranker: 1
n_facet_emb: 2
weight_mode: ''
context_norm: 1
post_remove_context: 0
partition_merging_mode: 'replace'
reranker_merging_mode: 'replace'
reranker_CAN_NUM: 100
use_proj_bias: 1
#Please see https://github.com/iesl/softmax_CPR_recommend/blob/master/run_hyper_loop.sh and [1] for some common configurations of the following hyperparameters
use_out_emb: False                 # (bool) If False, the output item embedding is shared with the input item embedding ([2] shows that this sharing can encourage item repetition).
n_facet_all: 5                     # (int) Number of linear layers for the context partition, reranker partition, pointer network, and most items in the vocabulary. Notice that n_facet_all = n_facet + n_facet_context + n_facet_reranker*len(reranker_CAN_NUM_arr) + n_facet_emb.
n_facet: 1                         # (int) Number of output hidden states for most items in the vocabulary. If n_facet > 1, we use a mixture of softmaxes (MoS).
n_facet_context: 1                 # (int) Number of output hidden states for the context partition. This number should be 0, 1, or n_facet (if you use MoS).
n_facet_reranker: 1                # (int) Number of output hidden states for a single reranker partition. This number should be 0, 1, or n_facet (if you use MoS).
reranker_CAN_NUM: 100              # (str) The sizes of the reranker partitions. If you want to use 3 reranker partitions with sizes 500, 100, and 20, set "500,100,20". Notice that the numbers should be in descending order (e.g., setting it to 20,100,500 is incorrect).
n_facet_emb: 2                     # (int) Number of output hidden states for the pointer network. This number should be either 0 or 2.
n_facet_hidden: 1                  # (int) min(n_facet_hidden, num_layers) is the H hyperparameter in multiple input hidden states (Mi) [3]. If not using Mi, set this number to 1.
n_facet_window: -2                 # (int) -n_facet_window + 1 is the W hyperparameter in multiple input hidden states [3]. If not using Mi, set this number to 0.
n_facet_MLP: -1                    # (int) The dimension of q_ct in [3] is (-n_facet_MLP + 1)*embedding_size. If not using Mi, set this number to 0.
weight_mode: ''                    # (str) The method of merging probability distributions in MoS. The value can be "dynamic" [4], "static", or "max_logits" [1].
context_norm: 1                    # (int) If set to 0, the denominator in Equation (5) of [1] is removed.
partition_merging_mode: 'replace'  # (str) If "replace", the logits from the context partition and pointer network overwrite the logits from the reranker partition and the original softmax; otherwise, the logits are added.
reranker_merging_mode: 'replace'   # (str) If "add", the logits from the reranker partition are added to the original softmax logits; otherwise, the original softmax logits are replaced by the reranker-partition logits.
use_proj_bias: 1                   # (bool) Whether to use the bias term in the linear layers for all output hidden states.
post_remove_context: 0             # (int) If 1, set the probability of all items in the history to 0 [2].

#[1] Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders." In Proceedings of The 17th ACM International Conference on Web Search and Data Mining (WSDM '24).
#[2] Ming Li, Ali Vardasbi, Andrew Yates, and Maarten de Rijke. 2023. Repetition and Exploration in Sequential Recommendation. In SIGIR 2023: 46th international ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2532–2541.
#[3] Haw-Shiuan Chang and Andrew McCallum. 2022. Softmax bottleneck makes language models unable to represent multi-mode word distributions. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 8048–8073
#[4] Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen. "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model." In International Conference on Learning Representations. 2018.
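
The n_facet_all constraint described in the comments above can be checked directly from the values in this file; a minimal sketch in Python, assuming reranker_CAN_NUM is parsed into a list of candidate-set sizes:

# Minimal sketch: the facet counts in GRU4RecCPR.yaml satisfy the stated constraint.
n_facet, n_facet_context, n_facet_reranker, n_facet_emb = 1, 1, 1, 2
reranker_CAN_NUM_arr = [100]  # parsed from reranker_CAN_NUM: 100
assert 5 == n_facet + n_facet_context + n_facet_reranker * len(reranker_CAN_NUM_arr) + n_facet_emb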
56 changes: 31 additions & 25 deletions recbole/properties/model/SASRecCPR.yaml
@@ -1,26 +1,32 @@
n_layers: 2 # (int) The number of transformer layers in transformer encoder.
n_heads: 2 # (int) The number of attention heads for multi-head attention layer.
hidden_size: 64 # (int) The number of features in the hidden state.
inner_size: 256 # (int) The inner hidden size in feed-forward layer.
hidden_dropout_prob: 0 # (float) The probability of an element to be zeroed.
attn_dropout_prob: 0.1 # (float) The probability of an attention score to be zeroed.
hidden_act: 'gelu' # (str) The activation function in feed-forward layer.
layer_norm_eps: 1e-12 # (float) A value added to the denominator for numerical stability.
initializer_range: 0.02 # (float) The standard deviation for normal initialization.
loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'.
n_layers: 2 # (int) The number of transformer layers in transformer encoder.
n_heads: 2 # (int) The number of attention heads for multi-head attention layer.
hidden_size: 64 # (int) The number of features in the hidden state.
inner_size: 256 # (int) The inner hidden size in feed-forward layer.
hidden_dropout_prob: 0 # (float) The probability of an element to be zeroed.
attn_dropout_prob: 0.1 # (float) The probability of an attention score to be zeroed.
hidden_act: 'gelu' # (str) The activation function in feed-forward layer.
layer_norm_eps: 1e-12 # (float) A value added to the denominator for numerical stability.
initializer_range: 0.02 # (float) The standard deviation for normal initialization.
loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'.

n_facet: 1
n_facet_all: 5
n_facet_hidden: 2
n_facet_window: -2
n_facet_MLP: -1
n_facet_context: 1
n_facet_reranker: 1
n_facet_emb: 2
weight_mode: ''
context_norm: 1
post_remove_context: 0
partition_merging_mode: 'replace'
reranker_merging_mode: 'replace'
reranker_CAN_NUM: 100
use_proj_bias: 1
#Please see https://github.com/iesl/softmax_CPR_recommend/blob/master/run_hyper_loop.sh and [1] for some common configurations of the following hyperparameters
n_facet_all: 5                     # (int) Number of linear layers for the context partition, reranker partition, pointer network, and most items in the vocabulary. Notice that n_facet_all = n_facet + n_facet_context + n_facet_reranker*len(reranker_CAN_NUM_arr) + n_facet_emb.
n_facet: 1                         # (int) Number of output hidden states for most items in the vocabulary. If n_facet > 1, we use a mixture of softmaxes (MoS).
n_facet_context: 1                 # (int) Number of output hidden states for the context partition. This number should be 0, 1, or n_facet (if you use MoS).
n_facet_reranker: 1                # (int) Number of output hidden states for a single reranker partition. This number should be 0, 1, or n_facet (if you use MoS).
reranker_CAN_NUM: 100              # (str) The sizes of the reranker partitions. If you want to use 3 reranker partitions with sizes 500, 100, and 20, set "500,100,20". Notice that the numbers should be in descending order (e.g., setting it to 20,100,500 is incorrect).
n_facet_emb: 2                     # (int) Number of output hidden states for the pointer network. This number should be either 0 or 2.
n_facet_hidden: 2                  # (int) min(n_facet_hidden, n_layers) is the H hyperparameter in multiple input hidden states (Mi) [3]. If not using Mi, set this number to 1.
n_facet_window: -2                 # (int) -n_facet_window + 1 is the W hyperparameter in multiple input hidden states [3]. If not using Mi, set this number to 0.
n_facet_MLP: -1                    # (int) The dimension of q_ct in [3] is (-n_facet_MLP + 1)*embedding_size. If not using Mi, set this number to 0.
weight_mode: ''                    # (str) The method of merging probability distributions in MoS. The value can be "dynamic" [4], "static", or "max_logits" [1].
context_norm: 1                    # (int) If set to 0, the denominator in Equation (5) of [1] is removed.
partition_merging_mode: 'replace'  # (str) If "replace", the logits from the context partition and pointer network overwrite the logits from the reranker partition and the original softmax; otherwise, the logits are added.
reranker_merging_mode: 'replace'   # (str) If "add", the logits from the reranker partition are added to the original softmax logits; otherwise, the original softmax logits are replaced by the reranker-partition logits.
use_proj_bias: 1                   # (bool) Whether to use the bias term in the linear layers for all output hidden states.
post_remove_context: 0             # (int) If 1, set the probability of all items in the history to 0 [2].

#[1] Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders." In Proceedings of The 17th ACM International Conference on Web Search and Data Mining (WSDM '24).
#[2] Ming Li, Ali Vardasbi, Andrew Yates, and Maarten de Rijke. 2023. Repetition and Exploration in Sequential Recommendation. In SIGIR 2023: 46th international ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2532–2541.
#[3] Haw-Shiuan Chang and Andrew McCallum. 2022. Softmax bottleneck makes language models unable to represent multi-mode word distributions. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 8048–8073
#[4] Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen. "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model." In International Conference on Learning Representations. 2018.
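
Either model can be trained from these defaults through RecBole's quick-start entry point; a minimal sketch, assuming the standard run_recbole signature and using the ml-100k benchmark dataset only as an example:

# Minimal sketch: train SASRecCPR with the defaults above, overriding one
# hyperparameter via config_dict (the dataset choice is an assumption).
from recbole.quick_start import run_recbole

run_recbole(
    model="SASRecCPR",
    dataset="ml-100k",
    config_dict={"train_neg_sample_args": None, "post_remove_context": 1},
)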
8 changes: 8 additions & 0 deletions tests/model/test_model_auto.py
@@ -458,6 +458,10 @@ def test_fpmc(self):
def test_gru4rec(self):
config_dict = {"model": "GRU4Rec", "train_neg_sample_args": None}
quick_test(config_dict)

def test_gru4reccpr(self):
config_dict = {"model": "GRU4RecCPR", "train_neg_sample_args": None}
quick_test(config_dict)

def test_gru4rec_with_BPR_loss(self):
config_dict = {
@@ -531,6 +535,10 @@ def test_transrec(self):
def test_sasrec(self):
config_dict = {"model": "SASRec", "train_neg_sample_args": None}
quick_test(config_dict)

def test_sasreccpr(self):
config_dict = {"model": "SASRecCPR", "train_neg_sample_args": None}
quick_test(config_dict)

def test_sasrec_with_BPR_loss_and_relu(self):
config_dict = {"model": "SASRec", "loss_type": "BPR", "hidden_act": "relu"}
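The two new cases can be run on their own with pytest's keyword filter; a minimal sketch using pytest's Python entry point (the -k expression simply matches the test names added above):

# Minimal sketch: run only the new GRU4RecCPR / SASRecCPR model tests.
import pytest

pytest.main(["tests/model/test_model_auto.py", "-k", "gru4reccpr or sasreccpr"])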

