Improve performance of MHA (#62)
cbalioglu authored Sep 21, 2023
1 parent 173c7ae commit a012f91
Showing 5 changed files with 252 additions and 214 deletions.
4 changes: 2 additions & 2 deletions src/fairseq2/nn/incremental_state.py
@@ -58,8 +58,8 @@ def increment_step(self, delta: int = 1) -> None:
         """Increment the step.
 
         This method should be called after every incremental evaluation (e.g.
-        beam search). It is used by modules to keep track of the position in
-        the sequence.
+        beam search) step. It is used by modules to keep track of the position
+        in the sequence.
 
         :param delta:
             The value by which to increment the step.
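The revised docstring describes a call contract that is easiest to see in a decoding loop. Below is a minimal sketch, not code from this commit: it assumes the zero-argument IncrementalStateBag constructor from this version of the module, and decode_step is a hypothetical stub standing in for a real incremental decoder forward pass.

import torch

from fairseq2.nn.incremental_state import IncrementalStateBag

state_bag = IncrementalStateBag()  # assumption: no-arg constructor, as in this file


def decode_step(last_token: torch.Tensor) -> torch.Tensor:
    # Stub for a real decoder forward pass that would read and update
    # cached keys/values stored in `state_bag`.
    return torch.randint(0, 100, (1, 1))


seq = torch.tensor([[2]])  # hypothetical BOS token id

for _ in range(8):
    next_token = decode_step(seq[:, -1:])
    seq = torch.cat([seq, next_token], dim=-1)

    # Per the docstring: increment after every incremental evaluation
    # (e.g. beam search) step, so modules can track the position in
    # the sequence.
    state_bag.increment_step()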
4 changes: 2 additions & 2 deletions src/fairseq2/nn/transformer/attention.py
@@ -198,11 +198,11 @@ def _naive_scaled_dot_product_attention(
     needs_weights: bool,
     training: bool,
 ) -> Tuple[Tensor, Optional[Tensor]]:
-    queries = queries * (queries.size(-1) ** -0.5)
-
     # (*, S, K) @ (*, K, S_kv) = (*, S, S_kv)
     attn_weights = torch.matmul(queries, keys.transpose(-1, -2))
 
+    attn_weights = attn_weights * (queries.size(-1) ** -0.5)
+
     if mask is not None:
         attn_weights = attn_weights + mask
 
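For context, this hunk moves the 1/sqrt(K) scaling from the query tensor to the attention weights. A standalone sketch of the two mathematically equivalent orderings; the shapes and variable names are illustrative, not fairseq2 API:

import torch

# Illustrative shapes: batch=2, S=4 (queries), S_kv=6 (keys), K=8 (head dim).
queries = torch.randn(2, 4, 8)
keys = torch.randn(2, 6, 8)

scale = queries.size(-1) ** -0.5  # 1 / sqrt(K)

# Before the change: scale the (*, S, K) queries, then matmul.
weights_before = torch.matmul(queries * scale, keys.transpose(-1, -2))

# After the change: matmul first, then scale the (*, S, S_kv) weights.
weights_after = torch.matmul(queries, keys.transpose(-1, -2)) * scale

# Same result up to floating-point rounding; the orderings differ only in
# which tensor receives the elementwise multiply.
torch.testing.assert_close(weights_before, weights_after)

This hunk alone need not account for the whole speedup; the performance improvement named in the commit title refers to the full set of changes across the five files listed above.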