diff --git a/recbole/model/sequential_recommender/repeatnet.py b/recbole/model/sequential_recommender/repeatnet.py index 161463252..cc7acf8d3 100644 --- a/recbole/model/sequential_recommender/repeatnet.py +++ b/recbole/model/sequential_recommender/repeatnet.py @@ -250,7 +250,9 @@ def forward(self, all_memory, last_memory, item_seq, mask=None): output_er = nn.Softmax(dim=-1)(output_er) batch_size, b_len = item_seq.size() - repeat_recommendation_decoder = torch.zeros([batch_size, self.num_item], device=self.device) + repeat_recommendation_decoder = torch.zeros( + [batch_size, self.num_item], device=self.device + ) repeat_recommendation_decoder.scatter_add_(1, item_seq, output_er) return repeat_recommendation_decoder.to(self.device) @@ -301,7 +303,9 @@ def forward(self, all_memory, last_memory, item_seq, mask=None): item_seq_first = item_seq[:, 0].unsqueeze(1).expand_as(item_seq) item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0) item_seq_first.requires_grad_(False) - output_e.scatter_add_(1, item_seq + item_seq_first, float('-inf') * torch.ones_like(item_seq)) + output_e.scatter_add_( + 1, item_seq + item_seq_first, float("-inf") * torch.ones_like(item_seq) + ) explore_recommendation_decoder = nn.Softmax(1)(output_e) return explore_recommendation_decoder