Fix view op in relative attention (#72)
cbalioglu authored Sep 28, 2023
1 parent c5f0d7e commit 599f698
Showing 1 changed file with 4 additions and 4 deletions.
src/fairseq2/nn/transformer/relative_attention.py (4 additions & 4 deletions)

@@ -150,11 +150,11 @@ def _compute_r(self, k: Tensor, batch_size: int) -> Tensor:
         # (2 x S - 1, K) -> (2 x S - 1, K)
         r = self.r_proj(r)

-        # (2 x S - 1, K) -> (N, 2 x S - 1, H, K_h)
-        r = r.view(batch_size, -1, self.num_heads, k.size(-1))
+        # (2 x S - 1, K) -> (1, 2 x S - 1, H, K_h)
+        r = r.view(1, -1, self.num_heads, k.size(-1))

-        # (N, 2 x S - 1, H, K_h) -> (N, H, 2 x S - 1, K_h)
-        r = r.transpose(1, 2)
+        # (1, 2 x S - 1, H, K_h) -> (N, H, 2 x S - 1, K_h)
+        r = r.transpose(1, 2).expand(batch_size, -1, -1, -1)

         return r  # type: ignore[no-any-return]
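The change is easy to misread in a flattened diff, so here is a minimal standalone sketch (plain PyTorch with made-up sizes; none of the names below come from fairseq2) of why the old view op was wrong and what the new one does. The tensor r holds one projected embedding per relative position and is shared across the whole batch, so reshaping it to a batch_size leading dimension either fails outright or scatters positions across batch entries; viewing it with a singleton batch dimension and then expand()-ing broadcasts it instead, without copying storage.

    import torch

    # Hypothetical sizes for illustration only.
    S, H, K_h = 5, 2, 4   # sequence length, attention heads, per-head dim
    N = 4                 # batch size
    K = H * K_h           # model dim

    # One projected embedding per relative position, shared across the
    # batch: shape (2 x S - 1, K) = (9, 8), 72 elements in total.
    r = torch.randn(2 * S - 1, K)

    # Old op: tries to carve the position axis into batch-size chunks.
    # It raises here because N * H * K_h does not divide 72, and even
    # when the sizes happen to line up, positions end up split across
    # batch entries rather than shared by them.
    try:
        r.view(N, -1, H, K_h)
    except RuntimeError as e:
        print(e)  # shape '[4, -1, 2, 4]' is invalid for input of size 72

    # New op: add a singleton batch dim, move heads in front of the
    # position axis, then broadcast to the batch size. expand() returns
    # a view; no data is copied.
    out = r.view(1, -1, H, K_h).transpose(1, 2).expand(N, -1, -1, -1)
    print(out.shape)      # torch.Size([4, 2, 9, 4])
    print(out.stride(0))  # 0 -> all batch entries share the same storage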
