diff --git a/src/fairseq2/generation/sequence_generator.py b/src/fairseq2/generation/sequence_generator.py index b229450fa..6f81f5cd7 100644 --- a/src/fairseq2/generation/sequence_generator.py +++ b/src/fairseq2/generation/sequence_generator.py @@ -532,18 +532,22 @@ def _bootstrap_seqs_and_scores( model_output = self.decoder.project(decoder_output, decoder_padding_mask) - # lprobs: (S_pfx - 1, V) - # model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V) - lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32) + # lprobs: (N, S_pfx - 1, V) + # model_output: (N, S_pfx - 1, V) -> (N, S_pfx - 1, V) + lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32) + + # (S_pfx) -> (1, S_pfx, 1) + indices = self.prefix_seq.unsqueeze(0).unsqueeze(2) + + # (1, S_pfx, 1) -> (N, S_pfx, 1) + indices = indices.expand(lprobs.size(0), -1, -1) # Fetch scores of next steps. - # (S_pfx - 1, 1) - prefix_scores = torch.take_along_dim( - lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1 - ) + # (N, S_pfx - 1, 1) + prefix_scores = torch.gather(lprobs, dim=-1, index=indices[:, 1:]) - # (S_pfx - 1, 1) -> (S_pfx - 1) - prefix_scores.squeeze_(1).cumsum_(dim=0) + # (N, S_pfx - 1, 1) -> (N, S_pfx - 1) + prefix_scores.squeeze_(-1).cumsum_(dim=-1) # First step (e.g. EOS)'s score is always 0. scores[:, 1 : self.prefix_seq_len] = prefix_scores