diff --git a/uer/targets/bilm_target.py b/uer/targets/bilm_target.py index a9c04edf..58ffb52a 100644 --- a/uer/targets/bilm_target.py +++ b/uer/targets/bilm_target.py @@ -26,9 +26,9 @@ def forward(self, memory_bank, tgt, seg): tgt_forward, tgt_backward = tgt[0], tgt[1] # Forward. loss_forward, correct_forward, _ = \ - self.lm(memory_bank[:, :, :self.hidden_size], tgt_forward) + self.lm(memory_bank[:, :, :self.hidden_size], tgt_forward, seg) # Backward. loss_backward, correct_backward, denominator_backward = \ - self.lm(memory_bank[:, :, self.hidden_size:], tgt_backward) + self.lm(memory_bank[:, :, self.hidden_size:], tgt_backward, seg) return loss_forward, loss_backward, correct_forward, correct_backward, denominator_backward