@aakhundov aakhundov commented Nov 19, 2024

Stack from ghstack (oldest at bottom):

Fixes #140800.

On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, even though they don't exist as kernel parameters. Native Triton code handles those extra args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out.
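
To illustrate the idea, here is a minimal sketch in plain Python. It is not the code from this PR or from Triton: `add_kernel` is a hypothetical stand-in for a kernel's Python signature, and `split_kernel_kwargs` is a made-up helper name. It only shows the general approach of separating kwargs that match kernel parameters from backend-specific ones that do not.

```python
import inspect


# Hypothetical stand-in for a Triton kernel's Python signature.
# A real kernel would be decorated with @triton.jit; plain Python is
# used here so the sketch runs without Triton or a GPU.
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE):
    pass


def split_kernel_kwargs(fn, kwargs):
    """Split kwargs into those naming a parameter of fn and the rest.

    Illustrative helper (assumed name): the "rest" bucket is where
    backend-specific options such as waves_per_eu would end up.
    """
    param_names = set(inspect.signature(fn).parameters)
    kept = {k: v for k, v in kwargs.items() if k in param_names}
    extra = {k: v for k, v in kwargs.items() if k not in param_names}
    return kept, extra


# Kwargs as they might arrive on AMD: two real kernel parameters plus
# three backend-specific options that are not kernel parameters.
launch_kwargs = {
    "n_elements": 1024,
    "BLOCK_SIZE": 128,
    "matrix_instr_nonkdim": 16,
    "waves_per_eu": 2,
    "kpack": 2,
}

kept, extra = split_kernel_kwargs(add_kernel, launch_kwargs)
print("kernel parameters:", kept)
print("backend-specific extras:", sorted(extra))
```

With the extras separated out, an analysis pass can work from `kept` alone instead of failing on names that the kernel signature does not declare.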

@aakhundov aakhundov requested a review from zou3519 as a code owner November 19, 2024 21:00
@pytorch-bot
pytorch-bot bot commented Nov 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141062

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 66bf59f with merge base 93aef68 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aakhundov added a commit that referenced this pull request Nov 19, 2024
Fixes #140800.

On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, even though they don't exist as kernel parameters. Native Triton code handles those extra args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out.

ghstack-source-id: 70037b8
Pull Request resolved: #141062
@aakhundov aakhundov added the ciflow/trunk and ciflow/inductor labels Nov 19, 2024
@aakhundov

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

jataylo pushed a commit to ROCm/pytorch that referenced this pull request Dec 4, 2024
…rch#141062)

Pull Request resolved: pytorch#141062
Approved by: https://github.com/oulgen

(cherry picked from commit b740a1b)
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rch#141062)

Pull Request resolved: pytorch#141062
Approved by: https://github.com/oulgen
@github-actions github-actions bot deleted the gh/aakhundov/17/head branch December 20, 2024 02:06
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Jan 31, 2025
… analysis (#141… (#1768)

Fixes pytorch#140800.

On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, even though they don't exist as kernel parameters. Native Triton code handles those extra args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out.
Pull Request resolved: pytorch#141062
Approved by: https://github.com/oulgen

(cherry picked from commit b740a1b)

Co-authored-by: Adnan Akhundov <[email protected]>

Labels

ciflow/inductor, ciflow/trunk, Merged, topic: not user facing
