Skip to content

Conversation

@aakhundov
Copy link
Contributor

@aakhundov aakhundov commented Oct 10, 2024

Stack from ghstack (oldest at bottom):

This adds Dynamo tracing support for the host-side Triton TMA API (see create_2d_tma_descriptor calls on the host in the Triton tutorial). A few notes:

  • Here we assume the availability of the host-side TMA API added to upstream Triton in [nvidia] Support passing TMA descriptors by-value triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
  • To capture the chain of calls t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...), we add three new variable trackers: DataPtrVariable, CreateTMADescriptorVariable (for the function), TMADescriptorVariable (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.
  • The newly introduced variables have reconstruct methods used in case of graph breaks.
  • The tma_descriptor_metadata extracted from the captured create_{1d,2d}_tma_descriptor calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like.
  • In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors.
  • In the Triton kernel mutation analysis pass (in AOTAutograd), we use the tt.experimental_descriptor_store TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required.
  • JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

Differential Revision: D64404928

@aakhundov aakhundov requested a review from zou3519 as a code owner October 10, 2024 01:49
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137677

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit db1362a with merge base 4a8e493 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)).

Design notes TBA.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)).

Design notes TBA.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 10, 2024
ghstack-source-id: f1e578b
Pull Request resolved: #137677
This adds Dynamo tracing support to the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)).

Design notes TBA.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 10, 2024
ghstack-source-id: 6814cce
Pull Request resolved: #137677
Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall, let's add some more testing around graph breaks.

@mobicham
Copy link

Is this gonna support pre-allocated descriptors on the host-side ? Creating descriptors on the fly for each kernel launch is very slow as you probably know.
Thanks!

@aakhundov
Copy link
Contributor Author

Is this gonna support pre-allocated descriptors on the host-side ? Creating descriptors on the fly for each kernel launch is very slow as you probably know. Thanks!

@mobicham thanks for chiming in! Yes, we want to support host-side TMA descriptor preallocation (and filling) with this PR and a (soon-to-be) follow up adding Inductor support. Along the lines of the matmul_tma_persistent in the Triton tutorial #9. Is this what you're referring to?

@mobicham
Copy link

@aakhundov perfect, that's exactly what I am looking for, thank you!

This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 11, 2024
ghstack-source-id: ed14d05
Pull Request resolved: #137677
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.
- The newly introduced variables have `reconstruct` methods used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@aakhundov
Copy link
Contributor Author

@aakhundov has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2024
@aakhundov
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Oct 16, 2024
Summary:
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.
- The newly introduced variables have `reconstruct` methods used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

X-link: pytorch/pytorch#137677
Approved by: https://github.com/zou3519

Reviewed By: clee2000

Differential Revision: D64404928

Pulled By: aakhundov

fbshipit-source-id: c812cea3867c55800d5fe213bf07bf21292345e3
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Oct 18, 2024
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- Due to Dynamo support implemented in the previous PR, the `tma_descriptor_metadata` dict is delivered to the `triton_kerenl_wrap_` lowering and passed to the `ir.UserDefinedTritonKernel` as additional argument.
- Looking into the `tma_descriptor_metadata`, `ir.UserDefinedTritonKernel` substitutes the corresponding `TensorBox` arguments of the kernel (swapped upstream in Dynamo) by the new `ir.TMADescriptor` nodes implementing TMA descriptors in Inductor IR.
- `ir.TMADescriptor.__init__` provides the wiring between the upstream underlying `ir.TensorBox` and the downstream `ir.UserDefinedTritonKernel` kernel. In particular, we use `ir.NonOwnedLayout` wrapping `ir.ReinterpretView` to avoid the upstream tensor's buffer being deleted prematurely (before the TMA descriptor is used in the Triton kernel).
- Via `ir.TMADescriptor.codegen`, the Triton's `create_{1d,2d}_tma_descriptor` function call is codegened in the wrapper (in the host code).
- New `TMADescriptorArg` dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA.
- AOT Inductor support will be implemented in a follow-up PR.

Pull Request resolved: pytorch#137950
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#137677
@github-actions github-actions bot deleted the gh/aakhundov/10/head branch November 16, 2024 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants