Add block mask utility support for batches and heads > 1 #130227
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130227
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit ed9a3e9 with merge base 6875179:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
aten/src/ATen/native/IndexingUtils.h (Outdated)
  }

- static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
+ static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
This is only used in 3 places, and all of them have allow_int set to True or should have it set to True.
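To make the point about `allow_int` concrete, here is a minimal Python sketch, not the ATen implementation. It assumes the flag only controls whether 32-bit integer index tensors are accepted alongside long, byte, and bool ones. The names `check_index_dtypes` and `check_index_dtypes_no_flag`, and the use of dtype strings in place of tensors, are hypothetical stand-ins for illustration.

```python
from typing import Optional, Sequence

# Hypothetical stand-ins for tensor dtypes; the real check inspects ATen scalar types.
BASE_INDEX_DTYPES = frozenset({"int64", "uint8", "bool"})
INT_INDEX_DTYPES = BASE_INDEX_DTYPES | {"int32"}


def check_index_dtypes(
    index_dtypes: Sequence[Optional[str]], allow_int: bool = False
) -> None:
    """Raise IndexError if any defined index has an unsupported dtype.

    None models an undefined or absent index, which is skipped.
    """
    allowed = INT_INDEX_DTYPES if allow_int else BASE_INDEX_DTYPES
    for dtype in index_dtypes:
        if dtype is None:
            continue
        if dtype not in allowed:
            raise IndexError(
                f"tensors used as indices must be one of {sorted(allowed)}, got {dtype}"
            )


def check_index_dtypes_no_flag(index_dtypes: Sequence[Optional[str]]) -> None:
    """Variant with the flag removed: behaves as if allow_int were always True."""
    check_index_dtypes(index_dtypes, allow_int=True)
```

If every caller already passes `allow_int=True` (or should), as the comment states, folding the flag into the function removes a parameter without changing behavior for those callers.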
@pytorchbot merge
Merge failed
Reason: This PR needs a […] If not, please add the […] To add a label, you can comment to pytorchbot, for example […] For more information, see […]
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge
This pull request was exported from Phabricator. Differential Revision: D59498662
…lexAttention numerics to be as accurate as FA2) (#130250)" This reverts commit 3e48d92. Reverted #130250 on behalf of https://github.com/izaitsevfb due to depends on #130227 which needs to be reverted ([comment](#130250 (comment)))
@pytorchbot revert -m "breaks internal builds, please see D59498662" -c ghfirst
@pytorchbot successfully started a revert job. Check the current status here.
@Chillee your PR has been successfully reverted.
…0227)" This reverts commit 6413998. Reverted #130227 on behalf of https://github.com/izaitsevfb due to breaks internal builds, please see D59498662 ([comment](#130227 (comment)))
…tion numerics to be as accurate as FA2) (#130250)
After this PR, our numerical error is within 3% of FA2 for the forward pass and gradients. Previously, our numerical error for `dq` was 30% higher. I also added a `PRESCALE_QK` kernel option that improves performance by about 3-4% but incurs about 20-30% more numerical error.
Pull Request resolved: #130250
Approved by: https://github.com/drisspg
ghstack dependencies: #130227
…tion numerics to be as accurate as FA2) (pytorch#130250) Pull Request resolved: pytorch#130250 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#130160, pytorch#130106, pytorch#130224, pytorch#130227
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang