Add scale kwarg to FlexAttention (and some changes that get FlexAttention numerics to be as accurate as FA2) #130250
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130250
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1db7a9c with merge base 6875179. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@Chillee would FlexAttention be made available as a backend for SDPA? E.g. will this enable an FAv2 impl with custom ...? Is it polishing the Triton impl of FAv2? I remember reading in issues that its perf was behind the CUDA version of FAv2, and people also complained of a slow backward.
)
return subgraph_buffer

def convert_output_node_to_buffer(output):
nit: I find inlined functions really hard to read, especially when they are nested multiple levels deep. Probably why I struggle with PT2 lol
I like inlined functions particularly in cases like this, because it keeps the definition of the function close to the actual usage.
drisspg left a comment:
left some comments, mostly nits
@pytorchbot revert -m "depends on #130227 which needs to be reverted" -c ghfirst
@pytorchbot successfully started a revert job. Check the current status here.
…lexAttention numerics to be as accurate as FA2) (#130250)" This reverts commit 3e48d92. Reverted #130250 on behalf of https://github.com/izaitsevfb due to depends on #130227 which needs to be reverted ([comment](#130250 (comment)))
@Chillee your PR has been successfully reverted.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
…tion numerics to be as accurate as FA2) (pytorch#130250) Pull Request resolved: pytorch#130250 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#130160, pytorch#130106, pytorch#130224, pytorch#130227
…tion numerics to be as accurate as FA2) (pytorch#130250) Pull Request resolved: pytorch#130250 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#130227
Stack from ghstack (oldest at bottom):
After this PR, our numerical error is within 3% of FA2 for forward and gradients. Prior, for `dq` our numerical error was 30% higher. I also added a `PRESCALE_QK` kernel option that increases perf by about 3-4% but incurs about 20-30% more numerical error.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
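To make the `scale` / `PRESCALE_QK` trade-off above concrete, here is a minimal pure-Python sketch. It is not the Triton kernel from this PR. It assumes that `scale` multiplies the `q·k` scores before the softmax (defaulting to `1/sqrt(head_dim)`, as in standard scaled dot-product attention), and that "prescaling" means multiplying `q` by the scale and rounding back to half precision before the dot products, rather than scaling the accumulated scores afterwards. The helper names `to_fp16` and `attention_row` and the `prescale_qk` flag are hypothetical, introduced only for illustration; fp16 rounding is emulated with `struct`'s `"e"` format and accumulation is done in Python floats.

```python
import math
import random
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_row(q, keys, values, scale=None, prescale_qk=False):
    """Attention output for a single query vector (hypothetical helper).

    Assumption: scale defaults to 1/sqrt(head_dim). With prescale_qk=True
    the scale is folded into q (rounded to fp16) before the dot products;
    otherwise the accumulated scores are scaled afterwards.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    if prescale_qk:
        q_scaled = [to_fp16(x * scale) for x in q]  # extra rounding step
        scores = [sum(a * b for a, b in zip(q_scaled, k)) for k in keys]
    else:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    probs = softmax(scores)
    return [sum(p * v[j] for p, v in zip(probs, values))
            for j in range(len(values[0]))]


# Small random problem with fp16-representable inputs.
rng = random.Random(0)
head_dim, seq_len = 96, 16
q = [to_fp16(rng.gauss(0, 1)) for _ in range(head_dim)]
keys = [[to_fp16(rng.gauss(0, 1)) for _ in range(head_dim)]
        for _ in range(seq_len)]
values = [[to_fp16(rng.gauss(0, 1)) for _ in range(head_dim)]
          for _ in range(seq_len)]

post = attention_row(q, keys, values)                    # scale after q.k
pre = attention_row(q, keys, values, prescale_qk=True)   # scale folded into q
max_diff = max(abs(a - b) for a, b in zip(post, pre))
print(f"max |post - pre| = {max_diff:.3e}")
```

In this sketch the prescaled path saves one multiply per score but rounds `q * scale` to fp16 up front, which is one plausible reading of why such an option would trade a little accuracy for speed; the actual error figures quoted above come from the PR's own measurements, not from this toy.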