Skip to content

Conversation

@Chillee
Copy link
Collaborator

@Chillee Chillee commented Jul 8, 2024

Stack from ghstack (oldest at bottom):

After this PR, our numerical error is within 3% of FA2 for forward and gradients. Prior, for dq our numerical error was 30% higher. I also added a PRESCALE_QK kernel option that increases perf by about 3-4% but incurs about 20-30% more numerical error.

image

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130250

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1db7a9c with merge base 6875179 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Chillee added a commit that referenced this pull request Jul 8, 2024
ghstack-source-id: 52fecb0
Pull Request resolved: #130250
@github-actions github-actions bot requested a review from ezyang July 8, 2024 16:23
@Chillee Chillee requested review from drisspg and yanboliang July 8, 2024 16:24
@vadimkantorov
Copy link
Contributor

vadimkantorov commented Jul 8, 2024

@Chillee would FlexAttention be made available as a backend for SDPA? E.g. will this enable FAv2 impl with custom attn_bias? (at least for forward pass)

Is it polishing the triton impl of FAv2? I remember reading in issues that its perf was behind CUDA version of FAv2... And also people complained of slow backward...

)
return subgraph_buffer

def convert_output_node_to_buffer(output):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I find inlined functions really hard to read especially if they are multiple levels of nested deep.. probs why I struggle with PT2 lol

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like inlined functions particularly in cases like this, because it keeps the definition of the function close to the actual usage.

Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left some comments, mostly nits

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@Chillee Chillee added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 8, 2024
@albanD albanD removed their request for review July 8, 2024 22:46
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@Chillee Chillee changed the title Add scale kwarg to FlexAttention Add scale kwarg to FlexAttention (and some changes that get FlexAttention numerics to be as accurate as FA2) Jul 9, 2024
…t FlexAttention numerics to be as accurate as FA2)"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@izaitsevfb
Copy link
Contributor

@pytorchbot revert -m "depends on #130227 which needs to be reverted" -c ghfirst

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Jul 9, 2024
…lexAttention numerics to be as accurate as FA2) (#130250)"

This reverts commit 3e48d92.

Reverted #130250 on behalf of https://github.com/izaitsevfb due to depends on #130227 which needs to be reverted ([comment](#130250 (comment)))
@pytorchmergebot
Copy link
Collaborator

@Chillee your PR has been successfully reverted.

…t FlexAttention numerics to be as accurate as FA2)"


After this PR, our numerical error is within 3% of FA2 for forward and gradients. Prior, for `dq` our numerical error was 30% higher. I also added a `PRESCALE_QK` kernel option that increases perf by about 3-4% but incurs about 20-30% more numerical error.

![image](https://github.com/pytorch/pytorch/assets/6355099/7b5ff44e-219b-4a05-8a1b-2a0182c01ab2)



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
…t FlexAttention numerics to be as accurate as FA2)"


After this PR, our numerical error is within 3% of FA2 for forward and gradients. Prior, for `dq` our numerical error was 30% higher. I also added a `PRESCALE_QK` kernel option that increases perf by about 3-4% but incurs about 20-30% more numerical error.

![image](https://github.com/pytorch/pytorch/assets/6355099/7b5ff44e-219b-4a05-8a1b-2a0182c01ab2)



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
Chillee added a commit that referenced this pull request Jul 10, 2024
ghstack-source-id: 4c8cd50
Pull Request resolved: #130250
@Chillee
Copy link
Collaborator Author

Chillee commented Jul 10, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

datagero pushed a commit to datagero/pytorch that referenced this pull request Jul 10, 2024
…lexAttention numerics to be as accurate as FA2) (pytorch#130250)"

This reverts commit 3e48d92.

Reverted pytorch#130250 on behalf of https://github.com/izaitsevfb due to depends on pytorch#130227 which needs to be reverted ([comment](pytorch#130250 (comment)))
@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@Chillee
Copy link
Collaborator Author

Chillee commented Jul 10, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

datagero pushed a commit to datagero/pytorch that referenced this pull request Jul 10, 2024
…lexAttention numerics to be as accurate as FA2) (pytorch#130250)"

This reverts commit 3e48d92.

Reverted pytorch#130250 on behalf of https://github.com/izaitsevfb due to depends on pytorch#130227 which needs to be reverted ([comment](pytorch#130250 (comment)))
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…lexAttention numerics to be as accurate as FA2) (pytorch#130250)"

This reverts commit 3e48d92.

Reverted pytorch#130250 on behalf of https://github.com/izaitsevfb due to depends on pytorch#130227 which needs to be reverted ([comment](pytorch#130250 (comment)))
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…tion numerics to be as accurate as FA2) (pytorch#130250)

After this PR, our numerical error is within 3% of FA2 for forward and gradients. Prior, for `dq` our numerical error was 30% higher. I also added a `PRESCALE_QK` kernel option that increases perf by about 3-4% but incurs about 20-30% more numerical error.

![image](https://github.com/pytorch/pytorch/assets/6355099/7b5ff44e-219b-4a05-8a1b-2a0182c01ab2)

Pull Request resolved: pytorch#130250
Approved by: https://github.com/drisspg
ghstack dependencies: pytorch#130227
@github-actions github-actions bot deleted the gh/chillee/319/head branch August 10, 2024 01:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants