Skip to content

Commit

Permalink
Improve prefix handling in sequence generator
Browse files (browse the repository at this point in the history)
  • Loading branch information
cbalioglu committed Sep 28, 2023
1 parent 599f698 commit a1656ef
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/fairseq2/generation/sequence_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -532,18 +532,22 @@ def _bootstrap_seqs_and_scores(

 model_output = self.decoder.project(decoder_output, decoder_padding_mask)

-# lprobs: (S_pfx - 1, V)
-# model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V)
-lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32)
+# lprobs: (N, S_pfx - 1, V)
+# model_output: (N, S_pfx - 1, V) -> (N, S_pfx - 1, V)
+lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32)

+# (S_pfx) -> (1, S_pfx, 1)
+indices = self.prefix_seq.unsqueeze(0).unsqueeze(2)
+
+# (1, S_pfx, 1) -> (N, S_pfx, 1)
+indices = indices.expand(lprobs.size(0), -1, -1)
+
 # Fetch scores of next steps.
-# (S_pfx - 1, 1)
-prefix_scores = torch.take_along_dim(
-    lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1
-)
+# (N, S_pfx - 1, 1)
+prefix_scores = torch.gather(lprobs, dim=-1, index=indices[:, 1:])

-# (S_pfx - 1, 1) -> (S_pfx - 1)
-prefix_scores.squeeze_(1).cumsum_(dim=0)
+# (N, S_pfx - 1, 1) -> (N, S_pfx - 1)
+prefix_scores.squeeze_(-1).cumsum_(dim=-1)

 # First step (e.g. EOS)'s score is always 0.
 scores[:, 1 : self.prefix_seq_len] = prefix_scores
Expand Down

0 comments on commit a1656ef

Please sign in to comment.