[js/webgpu] Optimize MultiHeadAttention|Transpose (#22420)
### Description
With this optimization, 96 MultiHeadAttention|Transpose ops in Phi3 disappear. Phi3 improves from 107 to 113 tokens per second on my dGPUs.

The optimization skips the transpose op when one of the transposed dims is 1; in that case a reshape is sufficient.
qjia7 authored Oct 14, 2024
1 parent de93f40 commit 0409c63
Showing 1 changed file with 6 additions and 0 deletions.
js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts

```diff
@@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = (
     if (input.dims.length === 3) {
       reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
     }
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
@@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = (
         biasOffset!,
       );
       reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
+      if (numHeads === 1 || sequenceLength === 1) {
+        return reshapedInput;
+      }
       return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
         inputs: [reshapedInput],
         outputs: [-1],
```
