Commit dae9b5f
Merge branch 'master' into fix-negative-sampling-bug
Ethan-TZ authored Oct 8, 2023
2 parents f103428 + 6d9f742
Showing 3 changed files with 34 additions and 10 deletions.
6 changes: 5 additions & 1 deletion recbole/model/general_recommender/pop.py

@@ -6,6 +6,10 @@
 # @Time : 2020/11/9
 # @Author : Zihan Lin
 # @Email : [email protected]
+# UPDATE
+# @Time : 2023/9/21
+# @Author : Kesha Ou
+# @Email : [email protected]

 r"""
 Pop
@@ -43,7 +47,7 @@ def calculate_loss(self, interaction):

         self.max_cnt = torch.max(self.item_cnt, dim=0)[0]

-        return torch.nn.Parameter(torch.zeros(1))
+        return torch.nn.Parameter(torch.zeros(1)).to(self.device)

     def predict(self, interaction):
         item = interaction[self.ITEM_ID]
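Pop has no trainable parameters, so calculate_loss returns a dummy zero parameter; the fix moves it onto the model's device, presumably to avoid a CPU/GPU mismatch when the trainer handles the loss downstream. A minimal sketch of the difference, assuming CUDA is available (it falls back to CPU otherwise):

import torch

# Hypothetical illustration; `device` stands in for the model's self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loss_before = torch.nn.Parameter(torch.zeros(1))            # always lives on CPU
loss_after = torch.nn.Parameter(torch.zeros(1)).to(device)  # follows the model's device

print(loss_before.device)  # cpu
print(loss_after.device)   # cuda:0 when a GPU is present, else cpu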
36 changes: 28 additions & 8 deletions recbole/model/sequential_recommender/bert4rec.py

@@ -3,6 +3,11 @@
 # @Author : Hui Wang
 # @Email : [email protected]

+# UPDATE
+# @Time : 2023/9/4
+# @Author : Enze Liu
+# @Email : [email protected]
+
 r"""
 BERT4Rec
 ################################################
@@ -75,6 +80,10 @@ def __init__(self, config, dataset):

         self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
+        self.output_ffn = nn.Linear(self.hidden_size, self.hidden_size)
+        self.output_gelu = nn.GELU()
+        self.output_ln = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+        self.output_bias = nn.Parameter(torch.zeros(self.n_items))

         # we only need compute the loss at the masked position
         try:
@@ -124,7 +133,9 @@ def forward(self, item_seq):
         trm_output = self.trm_encoder(
             input_emb, extended_attention_mask, output_all_encoded_layers=True
         )
-        output = trm_output[-1]
+        ffn_output = self.output_ffn(trm_output[-1])
+        ffn_output = self.output_gelu(ffn_output)
+        output = self.output_ln(ffn_output)
         return output  # [B L H]

     def multi_hot_embed(self, masked_index, max_length):
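The forward hunk above replaces the raw encoder output with a BERT-style prediction head: a position-wise linear layer, then GELU, then LayerNorm, mirroring the masked-language-model transform in the original BERT. A standalone sketch of that transform, with hidden_size, batch, and sequence sizes chosen purely for illustration:

import torch
import torch.nn as nn

# Arbitrary sizes for demonstration; the model would use its configured values.
hidden_size, layer_norm_eps = 64, 1e-12
output_ffn = nn.Linear(hidden_size, hidden_size)
output_gelu = nn.GELU()
output_ln = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

trm_last = torch.randn(2, 10, hidden_size)  # [B L H], last encoder layer
output = output_ln(output_gelu(output_ffn(trm_last)))
print(output.shape)  # torch.Size([2, 10, 64]) -- still [B L H]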
@@ -172,8 +183,14 @@ def calculate_loss(self, interaction):
         if self.loss_type == "BPR":
             pos_items_emb = self.item_embedding(pos_items)  # [B mask_len H]
             neg_items_emb = self.item_embedding(neg_items)  # [B mask_len H]
-            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B mask_len]
-            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B mask_len]
+            pos_score = (
+                torch.sum(seq_output * pos_items_emb, dim=-1)
+                + self.output_bias[pos_items]
+            )  # [B mask_len]
+            neg_score = (
+                torch.sum(seq_output * neg_items_emb, dim=-1)
+                + self.output_bias[neg_items]
+            )  # [B mask_len]
             targets = (masked_index > 0).float()
             loss = -torch.sum(
                 torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets
@@ -183,8 +200,9 @@
         elif self.loss_type == "CE":
             loss_fct = nn.CrossEntropyLoss(reduction="none")
             test_item_emb = self.item_embedding.weight[: self.n_items]  # [item_num H]
-            logits = torch.matmul(
-                seq_output, test_item_emb.transpose(0, 1)
+            logits = (
+                torch.matmul(seq_output, test_item_emb.transpose(0, 1))
+                + self.output_bias
             )  # [B mask_len item_num]
             targets = (masked_index > 0).float().view(-1)  # [B*mask_len]
@@ -204,7 +222,9 @@ def predict(self, interaction):
         seq_output = self.forward(item_seq)
         seq_output = self.gather_indexes(seq_output, item_seq_len - 1)  # [B H]
         test_item_emb = self.item_embedding(test_item)
-        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
+        scores = (torch.mul(seq_output, test_item_emb)).sum(dim=1) + self.output_bias[
+            test_item
+        ]  # [B]
         return scores

     def full_sort_predict(self, interaction):
@@ -216,7 +236,7 @@
         test_items_emb = self.item_embedding.weight[
             : self.n_items
         ]  # delete masked token
-        scores = torch.matmul(
-            seq_output, test_items_emb.transpose(0, 1)
+        scores = (
+            torch.matmul(seq_output, test_items_emb.transpose(0, 1)) + self.output_bias
         )  # [B, item_num]
         return scores
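predict scores only the given candidate items, so it gathers the matching bias entries with output_bias[test_item], while full_sort_predict adds the whole bias vector to all item scores at once. A small sketch of the gather, with hypothetical sizes:

import torch

# Hypothetical sizes: B=4 users, H=8, item_num=100.
B, H, item_num = 4, 8, 100
seq_output = torch.randn(B, H)
item_embedding = torch.randn(item_num, H)
output_bias = torch.zeros(item_num)

test_item = torch.randint(0, item_num, (B,))  # one candidate item per user
test_item_emb = item_embedding[test_item]     # [B H]
scores = (seq_output * test_item_emb).sum(dim=1) + output_bias[test_item]  # [B]
print(scores.shape)  # torch.Size([4])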
2 changes: 1 addition & 1 deletion requirements.txt

@@ -10,7 +10,7 @@ colorlog==4.7.2
 colorama==0.4.4
 tensorboard>=2.5.0
 thop>=0.1.1.post2207130030
-ray>=1.13.0
+ray>=1.13.0, <=2.6.3
 tabulate>=0.8.10
 plotly>=4.0.0
 texttable>=0.9.0