-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add host-side Triton TMA support to Dynamo #137677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137677
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit db1362a with merge base 4a8e493 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). Design notes TBA. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). Design notes TBA. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). Design notes TBA. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, let's add some more testing around graph breaks.
|
Is this gonna support pre-allocated descriptors on the host-side ? Creating descriptors on the fly for each kernel launch is very slow as you probably know. |
@mobicham thanks for chiming in! Yes, we want to support host-side TMA descriptor preallocation (and filling) with this PR and a (soon-to-be) follow up adding Inductor support. Along the lines of the |
|
@aakhundov perfect, that's exactly what I am looking for, thank you! |
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes: - Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024. - To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created. - The newly introduced variables have `reconstruct` methods used in case of graph breaks. - The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like. - In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors. - In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required. - JIT Inductor and AOT Inductor support will be implemented in follow-up PRs. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
|
@aakhundov has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary: This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes: - Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024. - To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created. - The newly introduced variables have `reconstruct` methods used in case of graph breaks. - The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like. - In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors. - In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required. - JIT Inductor and AOT Inductor support will be implemented in follow-up PRs. X-link: pytorch/pytorch#137677 Approved by: https://github.com/zou3519 Reviewed By: clee2000 Differential Revision: D64404928 Pulled By: aakhundov fbshipit-source-id: c812cea3867c55800d5fe213bf07bf21292345e3
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes: - Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024. - Due to Dynamo support implemented in the previous PR, the `tma_descriptor_metadata` dict is delivered to the `triton_kerenl_wrap_` lowering and passed to the `ir.UserDefinedTritonKernel` as additional argument. - Looking into the `tma_descriptor_metadata`, `ir.UserDefinedTritonKernel` substitutes the corresponding `TensorBox` arguments of the kernel (swapped upstream in Dynamo) by the new `ir.TMADescriptor` nodes implementing TMA descriptors in Inductor IR. - `ir.TMADescriptor.__init__` provides the wiring between the upstream underlying `ir.TensorBox` and the downstream `ir.UserDefinedTritonKernel` kernel. In particular, we use `ir.NonOwnedLayout` wrapping `ir.ReinterpretView` to avoid the upstream tensor's buffer being deleted prematurely (before the TMA descriptor is used in the Triton kernel). - Via `ir.TMADescriptor.codegen`, the Triton's `create_{1d,2d}_tma_descriptor` function call is codegened in the wrapper (in the host code). - New `TMADescriptorArg` dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA. - AOT Inductor support will be implemented in a follow-up PR. Pull Request resolved: pytorch#137950 Approved by: https://github.com/eellison ghstack dependencies: pytorch#137677
Stack from ghstack (oldest at bottom):
This adds Dynamo tracing support for the host-side Triton TMA API (see
create_2d_tma_descriptorcalls on the host in the Triton tutorial). A few notes:t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...), we add three new variable trackers:DataPtrVariable,CreateTMADescriptorVariable(for the function),TMADescriptorVariable(for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.reconstructmethods used in case of graph breaks.tma_descriptor_metadataextracted from the capturedcreate_{1d,2d}_tma_descriptorcalls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like.tt.experimental_descriptor_storeTTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec
Differential Revision: D64404928