Add dtype attribute to CSEVariable #136778
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136778
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 3 Unrelated Failures as of commit ea0a9e6 with merge base 245026a.
NEW FAILURES - the following jobs have failed.
BROKEN TRUNK - the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D61815079
Summary:
Pull Request resolved: pytorch#136778
- This diff introduces a `dtype` attribute to `TritonCSEVariable` and a dtype propagation helper function to infer dtype from input to output for each op.
- There will be a follow-up diff that uses this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.

Test Plan: CI

Differential Revision: D61815079
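To make the intent concrete, here is a minimal, hypothetical sketch of what a dtype-carrying CSE variable and a propagation helper could look like. The class name, function name, and promotion rules below are illustrative only and are not the actual `torch._inductor` implementation.

```python
# Hedged sketch: names and signatures are made up; only the idea of carrying
# a dtype on each CSE variable and inferring it per-op is from the PR.
from typing import Optional
import torch


class CSEVariableSketch:
    """A CSE variable that also records the dtype of the value it names."""

    def __init__(self, name: str, dtype: Optional[torch.dtype] = None) -> None:
        self.name = name
        self.dtype = dtype  # filled in by the propagation helper below

    def __repr__(self) -> str:
        return f"{self.name}: {self.dtype}"


def propagate_dtype(op: str, *inputs: CSEVariableSketch) -> Optional[torch.dtype]:
    """Infer an output dtype from the input dtypes of an op.

    Comparisons produce booleans; everything else follows ordinary type
    promotion over the inputs (a simplification of real per-op rules).
    """
    if op in ("eq", "ne", "lt", "le", "gt", "ge"):
        return torch.bool
    known = [v.dtype for v in inputs if v.dtype is not None]
    if not known:
        return None
    out = known[0]
    for dt in known[1:]:
        out = torch.promote_types(out, dt)
    return out


# Usage: the result of an fp16 + fp32 add is tracked as fp32.
a = CSEVariableSketch("tmp0", torch.float16)
b = CSEVariableSketch("tmp1", torch.float32)
c = CSEVariableSketch("tmp2", propagate_dtype("add", a, b))
print(c)  # tmp2: torch.float32
```

With dtype tracked this way, a follow-up change can consult `c.dtype` at codegen time, which is what "dtype-aware codegen" in the summary refers to.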
Merge started: your change will be merged while ignoring the following 0 checks. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -i (Initiating merge automatically since Phabricator Diff has merged, merging with -i because oss signals were bypassed internally)
Merge started: your change will be merged while ignoring the following 2 checks: pull / linux-focal-cuda12.4-py3.10-gcc9-bazel-test / build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu), pull / linux-focal-cuda12.1-py3.10-gcc9-bazel-test / build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchmergebot merge -f "macos runner issues - pushing this one through"
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR extends our existing ability to fuse pointwise nodes onto the outputs of Triton templates (epilogue fusion) with the ability to fuse pointwise nodes into the inputs of Triton templates - prologue fusion.
Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`
And the modification api:
```
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
out="qk"
) | indent_except_first(1) }}
```
We have:
```
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
```
Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel so that it no longer advances the [pointers by BLOCK_K](https://github.com/pytorch/pytorch/blob/bb03ef7acadf7bbc3287b0ada58e9476eda6e0fe/torch/_inductor/kernel/mm.py#L110-L111) on every iteration and instead recomputes the indices from the k_idx of each loop iteration. This did not make any perf difference.
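As a rough illustration of that addressing change, here is a plain-Python sketch (not the real Triton kernel; the names, strides, and tile size are made up) showing that the two schemes produce the same offsets, with the second one recomputing them from `k_idx`, which is what lets a fused prologue load with explicit indices and a mask:

```python
# Toy, CPU-only sketch of the two addressing schemes for the K loop of a
# tiled matmul over a row-major (M, K) matrix. All names and sizes are made up.
M, K, BLOCK_K = 4, 16, 4
stride_am, stride_ak = K, 1  # row-major strides

def offsets_by_pointer_advance(row: int) -> list[list[int]]:
    """Old scheme: keep a running base offset and bump it by BLOCK_K each step."""
    base = row * stride_am
    tiles = []
    for _ in range(K // BLOCK_K):
        tiles.append([base + i * stride_ak for i in range(BLOCK_K)])
        base += BLOCK_K * stride_ak
    return tiles

def offsets_from_k_idx(row: int) -> list[list[int]]:
    """New scheme: recompute the offsets from k_idx on every iteration, so a
    fused prologue can load with explicit indices (and a mask) rather than
    following a moving pointer."""
    tiles = []
    for k_idx in range(0, K, BLOCK_K):
        tiles.append([row * stride_am + (k_idx + i) * stride_ak
                      for i in range(BLOCK_K)])
    return tiles

# Both schemes visit exactly the same elements.
assert offsets_by_pointer_advance(2) == offsets_from_k_idx(2)
```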
There are a couple of main use cases for prologue fusion:
- Fusing dequants into a matmul, particularly for more bandwidth-bound scenarios (see the sketch after this list).
- Fusing a gather into a matmul. This is particularly useful in MoE. See #134535 for more details.
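For the dequant case, here is a hedged end-to-end example of the kind of user pattern this targets; the function, shapes, and scale values are made up, and running it requires a CUDA device with `torch.compile` available:

```python
import torch

# Hypothetical dequant-into-matmul pattern: the pointwise cast/scale of the
# int8 weight is the prologue that fusion tries to fold into the matmul
# template's loads of `w`, rather than materializing `w` in a separate kernel.
@torch.compile
def dequant_mm(x: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    w = w_int8.to(torch.float16) * scale  # pointwise prologue on the weight
    return x @ w                          # Triton matmul template

x = torch.randn(64, 128, device="cuda", dtype=torch.float16)
w_int8 = torch.randint(-128, 127, (128, 256), device="cuda", dtype=torch.int8)
scale = torch.full((256,), 0.05, device="cuda", dtype=torch.float16)
out = dequant_mm(x, w_int8, scale)
```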
Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read beyond what the Triton template already reads, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like an fp32->fp16 cast inside the kernel. In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
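A minimal sketch of that profitability gate, assuming a made-up byte-accounting interface and slack factor (this is not the exact Inductor heuristic or its constants):

```python
def should_fuse_prologue(template_bytes_read: int,
                         fused_bytes_read: int,
                         gather_allowance: float = 1.25) -> bool:
    """Fuse the prologue only if it does not read meaningfully more bytes than
    the template would read for that input on its own; the small slack factor
    is what lets gathers (which also read an index tensor) still qualify."""
    return fused_bytes_read <= template_bytes_read * gather_allowance


# Example: casting an fp32 input down to fp16 inside the kernel doubles the
# bytes read per element versus reading an fp16 input, so it is rejected.
print(should_fuse_prologue(template_bytes_read=2 * 1024, fused_bytes_read=4 * 1024))  # False
```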
Other notes:
By default we upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need #136778 and dtype-aware codegen to upcast fp16 ops into libdevice calls.
With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs, this could get cleaned up a bit.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec
Pull Request resolved: #134532
Approved by: https://github.com/jansel
Summary:
- This diff introduces a `dtype` attribute to `TritonCSEVariable` and a dtype propagation helper function to infer dtype from input to output for each op.
- There will be a follow-up diff that uses this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.

Test Plan: CI

Differential Revision: D61815079
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov