-
Notifications
You must be signed in to change notification settings - Fork 620
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into fix-negative-sampling-bug
- Loading branch information
Showing
3 changed files
with
34 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,10 @@ | |
# @Time : 2020/11/9 | ||
# @Author : Zihan Lin | ||
# @Email : [email protected] | ||
# UPDATE | ||
# @Time :2023/9/21 | ||
# @Author : Kesha Ou | ||
# @Email :[email protected] | ||
|
||
r""" | ||
Pop | ||
|
@@ -43,7 +47,7 @@ def calculate_loss(self, interaction): | |
|
||
self.max_cnt = torch.max(self.item_cnt, dim=0)[0] | ||
|
||
return torch.nn.Parameter(torch.zeros(1)) | ||
return torch.nn.Parameter(torch.zeros(1)).to(self.device) | ||
|
||
def predict(self, interaction): | ||
item = interaction[self.ITEM_ID] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,11 @@ | |
# @Author : Hui Wang | ||
# @Email : [email protected] | ||
|
||
# UPDATE | ||
# @Time : 2023/9/4 | ||
# @Author : Enze Liu | ||
# @Email : [email protected] | ||
|
||
r""" | ||
BERT4Rec | ||
################################################ | ||
|
@@ -75,6 +80,10 @@ def __init__(self, config, dataset): | |
|
||
self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) | ||
self.dropout = nn.Dropout(self.hidden_dropout_prob) | ||
self.output_ffn = nn.Linear(self.hidden_size, self.hidden_size) | ||
self.output_gelu = nn.GELU() | ||
self.output_ln = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) | ||
self.output_bias = nn.Parameter(torch.zeros(self.n_items)) | ||
|
||
# we only need compute the loss at the masked position | ||
try: | ||
|
@@ -124,7 +133,9 @@ def forward(self, item_seq): | |
trm_output = self.trm_encoder( | ||
input_emb, extended_attention_mask, output_all_encoded_layers=True | ||
) | ||
output = trm_output[-1] | ||
ffn_output = self.output_ffn(trm_output[-1]) | ||
ffn_output = self.output_gelu(ffn_output) | ||
output = self.output_ln(ffn_output) | ||
return output # [B L H] | ||
|
||
def multi_hot_embed(self, masked_index, max_length): | ||
|
@@ -172,8 +183,14 @@ def calculate_loss(self, interaction): | |
if self.loss_type == "BPR": | ||
pos_items_emb = self.item_embedding(pos_items) # [B mask_len H] | ||
neg_items_emb = self.item_embedding(neg_items) # [B mask_len H] | ||
pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B mask_len] | ||
neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B mask_len] | ||
pos_score = ( | ||
torch.sum(seq_output * pos_items_emb, dim=-1) | ||
+ self.output_bias[pos_items] | ||
) # [B mask_len] | ||
neg_score = ( | ||
torch.sum(seq_output * neg_items_emb, dim=-1) | ||
+ self.output_bias[neg_items] | ||
) # [B mask_len] | ||
targets = (masked_index > 0).float() | ||
loss = -torch.sum( | ||
torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets | ||
|
@@ -183,8 +200,9 @@ def calculate_loss(self, interaction): | |
elif self.loss_type == "CE": | ||
loss_fct = nn.CrossEntropyLoss(reduction="none") | ||
test_item_emb = self.item_embedding.weight[: self.n_items] # [item_num H] | ||
logits = torch.matmul( | ||
seq_output, test_item_emb.transpose(0, 1) | ||
logits = ( | ||
torch.matmul(seq_output, test_item_emb.transpose(0, 1)) | ||
+ self.output_bias | ||
) # [B mask_len item_num] | ||
targets = (masked_index > 0).float().view(-1) # [B*mask_len] | ||
|
||
|
@@ -204,7 +222,9 @@ def predict(self, interaction): | |
seq_output = self.forward(item_seq) | ||
seq_output = self.gather_indexes(seq_output, item_seq_len - 1) # [B H] | ||
test_item_emb = self.item_embedding(test_item) | ||
scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] | ||
scores = (torch.mul(seq_output, test_item_emb)).sum(dim=1) + self.output_bias[ | ||
test_item | ||
] # [B] | ||
return scores | ||
|
||
def full_sort_predict(self, interaction): | ||
|
@@ -216,7 +236,7 @@ def full_sort_predict(self, interaction): | |
test_items_emb = self.item_embedding.weight[ | ||
: self.n_items | ||
] # delete masked token | ||
scores = torch.matmul( | ||
seq_output, test_items_emb.transpose(0, 1) | ||
scores = ( | ||
torch.matmul(seq_output, test_items_emb.transpose(0, 1)) + self.output_bias | ||
) # [B, item_num] | ||
return scores |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters