
Conversation

@aakhundov (Contributor) commented Oct 25, 2024

Stack from ghstack (oldest at bottom):

  • (to be filled)

This fixes some leftover typos in CreateTMADescriptorVariable.call_function (and close by).

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

[ghstack-poisoned]
@pytorch-bot (bot) commented Oct 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138877

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a5e235e with merge base 72ea7ba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@aakhundov added the topic: not user facing label Oct 25, 2024
@aakhundov requested a review from eellison October 25, 2024 01:23
@aakhundov added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 25, 2024
@aakhundov (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

kwargs["block_dim0"] if "block_dim0" in kwargs else args[4],
]
element_size = kwargs["ptr"] if "ptr" in kwargs else args[-1]
element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
Collaborator:

Suggested change:
- element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
+ element_size = kwargs.get("element_size", args[-1])

Collaborator:

And the same for all the other kwargs accesses of this kind, to avoid accidental typos in the key.

Contributor Author:

It may be that args is empty and all arguments are passed via kwargs. Would kwargs.get("element_size", args[-1]) not attempt to evaluate args[-1] first? Example:

>>> args = []
>>> kwargs = {"arg": 1}
>>> kwargs["arg"] if "arg" in kwargs else args[-1]
1
>>> kwargs.get("arg", args[-1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: list index out of range
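
(As a side note, not something proposed in this thread: one minimal sketch of a middle ground is a small helper that names the key only once while keeping the fallback lazy; the helper name below is made up for illustration.)

def _kwarg_or_positional(kwargs, key, args, index):
    # Lazy equivalent of `kwargs[key] if key in kwargs else args[index]`:
    # args[index] is only evaluated when the key is actually missing.
    if key in kwargs:
        return kwargs[key]
    return args[index]

args = []
kwargs = {"element_size": 4}
_kwarg_or_positional(kwargs, "element_size", args, -1)  # -> 4, no IndexError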

@aakhundov (Contributor, Author):

@pytorchbot merge -f

@pytorch-bot (bot) commented Oct 25, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot merge: error: argument -f/--force: expected one argument

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Try @pytorchbot --help for more info.
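
(For reference, and purely as an illustrative example: the usage string above shows that -f expects a message argument, so a working force-merge invocation would look something like the following, with the message text being a placeholder.)

@pytorchbot merge -f "your reason for force merging"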

@aakhundov (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@Skylion007 (Collaborator):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 28, 2024
This adds host-side Triton TMA support to AOTInductor. Notes:

- Two helper functions, `init1DTMADescriptor` and `init2DTMADescriptor` are added to the C++ wrapper codegen on GPU, conditioned on the model having user-defined Triton kernels with host-side TMA (CUDA-specific).
- C++ wrapper codegen on GPU emits TMA descriptor initialization via the aforementioned helper functions.
- Special handling added for the TMA descriptors (in the Python wrapper codegen) during the compile-time autotuning, as the underlying tensor can't be passed directly to the user-defined Triton kernel. TMA descriptors are generated in-between the source tensor's buffer and the kernel call, like in the full Python wrapper codegen.
- This PR concludes the host-side Triton TMA support in PT2.

Pull Request resolved: #138878
Approved by: https://github.com/desertfire, https://github.com/chenyang78
ghstack dependencies: #138759, #138877
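
(For illustration only, not code from this PR or its stack: a rough sketch of what a user-defined Triton kernel with host-side TMA looks like from the user's side. It assumes Triton's experimental descriptor helpers create_1d_tma_descriptor and tl._experimental_descriptor_load/_experimental_descriptor_store, plus a Hopper-class GPU; the kernel and tensor names are made up.)

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor

@triton.jit
def add_one_kernel(desc, BLOCK: tl.constexpr):
    # Load one BLOCK-sized tile through the host-created TMA descriptor,
    # add one, and store it back through the same descriptor.
    x = tl._experimental_descriptor_load(desc, [0], [BLOCK], tl.float32)
    tl._experimental_descriptor_store(desc, x + 1.0, [0])

t = torch.zeros(128, device="cuda", dtype=torch.float32)
# Host side: the descriptor is built from the tensor's raw pointer, its size,
# the block size, and the element size -- the same kind of arguments that
# CreateTMADescriptorVariable.call_function pulls out of args/kwargs.
desc = create_1d_tma_descriptor(t.data_ptr(), t.numel(), 128, t.element_size())
add_one_kernel[(1,)](desc, BLOCK=128)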
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Oct 29, 2024
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
@github-actions github-actions bot deleted the gh/aakhundov/13/head branch November 26, 2024 02:09