[webgpu] Flash attention for generation #23808
Merged
Conversation
1. Only copy the new KV data for the static KV cache.
2. Add flash decoding for sequence_length = 1.
Force-pushed 6f6d6d1 to f0424fd
Contributor: can you merge with main?
Contributor (Author): Done. This PR is ready for review. Thanks.
sushraja-msft requested changes on Mar 26, 2025
qjia7 (Contributor, Author) commented on Mar 27, 2025:
> Rename valid_new_present_shape to copy_kv_shape to make it easier to understand.

Thanks for your suggestion.
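For context, "only copy the new KV data" means each decode step writes just the newly generated rows into a preallocated static cache rather than rewriting the whole tensor. A minimal NumPy sketch of that idea follows; the `append_kv` helper and shapes are hypothetical illustrations, not the ONNX Runtime implementation, and `copy_kv_shape` here simply names the region actually written:

```python
import numpy as np

# Hypothetical static KV cache, preallocated to max_seq_len.
max_seq_len, head_size = 8, 4
cache_k = np.zeros((max_seq_len, head_size), dtype=np.float32)

def append_kv(cache, new_k, past_len):
    """Copy only the new rows (the copy_kv_shape region) into the cache.

    Writes cache[past_len : past_len + new_len] and returns the updated
    sequence length; rows beyond it stay untouched.
    """
    new_len = new_k.shape[0]
    cache[past_len:past_len + new_len] = new_k
    return past_len + new_len

# Prefill writes 3 rows, then one decode step appends a single row.
past = append_kv(cache_k, np.ones((3, head_size), np.float32), 0)
past = append_kv(cache_k, np.full((1, head_size), 2.0, np.float32), past)
```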
sushraja-msft approved these changes on Apr 4, 2025
sushraja-msft previously approved these changes on Apr 4, 2025
guschmue previously approved these changes on Apr 4, 2025
Contributor: can you merge with main?
qjia7 (Contributor, Author) commented on Apr 7, 2025:
> can you merge with main?

Done
sushanthr approved these changes on Apr 8, 2025
guschmue approved these changes on Apr 8, 2025
zhaoxul-qti pushed a commit to CodeLinaro/onnxruntime that referenced this pull request on Apr 17, 2025
This PR adds flash decoding support to optimize generation speed when the total sequence length is large. Previously, once the total sequence length grew large enough, the softmax and softmax * V shaders became the bottleneck because they used only a limited number of GPU cores. This change adds flash decoding: the present key/value is split along the total sequence length, and a reduce pass then combines the partial results into the final output.
On an NV RTX 2000 Ada, TPS improves from 34.4 to 41.4 at 1K tokens for phi4 with a static KV cache.
On Meteor Lake, TPS improves from 16 to 19 at 1K tokens for phi4 with a static KV cache.
Side effect of this PR:
It adds two extra buffers, which increases memory usage: 1) metadata (the max and exp_sum of each split), and 2) the split QKV results with shape [B, N, split_k, H].
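The split-then-reduce scheme described above can be sketched numerically: each split computes an unnormalized partial output plus its metadata (local max and exp_sum), and the reduce pass rescales each partial by exp(m_i - m_global) before combining. The NumPy version below is an illustrative single-head, single-query-token sketch, not the actual WGSL shaders; the function names and split count are assumptions for the example:

```python
import numpy as np

def attention_reference(q, k, v):
    """Plain softmax attention for one query vector q [H] over k, v [T, H]."""
    scores = k @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ v

def flash_decode(q, k, v, num_splits):
    """Split K/V along the sequence, then reduce the partial results."""
    head_size = q.shape[0]
    maxes, exp_sums, partials = [], [], []
    for ki, vi in zip(np.array_split(k, num_splits),
                      np.array_split(v, num_splits)):
        s = ki @ q / np.sqrt(head_size)
        m = s.max()                 # per-split metadata: max ...
        e = np.exp(s - m)
        maxes.append(m)
        exp_sums.append(e.sum())    # ... and exp_sum
        partials.append(e @ vi)     # unnormalized partial output [H]
    # Reduce pass: rescale each split to the global max, then combine.
    m_global = max(maxes)
    scales = [np.exp(m - m_global) for m in maxes]
    total = sum(s * sc for s, sc in zip(exp_sums, scales))
    return sum(p * sc for p, sc in zip(partials, scales)) / total
```

Because exp(s_i - m_i) * exp(m_i - m_global) = exp(s_i - m_global), the rescaled partials sum to exactly the single-pass softmax result, so the splits can run on separate GPU cores without changing the output.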
TODO:
Ideally there should be only two shaders, which would also reduce the intermediate memory: the computeQKT shader could be merged into the split shader, with the final softmax adjustment done in the reduce shader. However, I hit an issue where the result becomes garbage once the total sequence length exceeds a certain value. Since I can't resolve it quickly, I am leaving it as a TODO to fix in the future.