
Conversation

@eellison (Contributor) commented on Aug 27, 2024

Stack from ghstack (oldest at bottom):

This PR extends our ability to fuse pointwise nodes onto triton templates (epilogue fusion) with the ability to fuse pointwise nodes into triton templates - prologue fusion.

Similar to the store_output api:
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}

And the modification api:

{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}

We have:

{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}

Because we now load the input with explicit indices and a mask, I needed to rewrite the mm kernel so that it no longer advances the pointers by BLOCK_K on every iteration and instead recomputes the indices from the k_idx of each loop iteration. This did not make any perf difference.
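
A minimal sketch of that indexing scheme (illustrative names and signature, not the actual inductor template): indices are recomputed from k_idx on each iteration, so a fused prologue load gets explicit indices plus a mask.

import triton
import triton.language as tl

@triton.jit
def mm_kernel(A, B, C, M, N, K,
              stride_am, stride_ak, stride_bk, stride_bn,
              stride_cm, stride_cn,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    idx_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
        # indices derived from k_idx every iteration, no pointer advancement
        k = k_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        a_mask = (idx_m[:, None] < M) & (k[None, :] < K)
        b_mask = (k[:, None] < K) & (idx_n[None, :] < N)
        # a {{load_input(...)}} prologue would be inlined at these loads
        a = tl.load(A + idx_m[:, None] * stride_am + k[None, :] * stride_ak, mask=a_mask, other=0.0)
        b = tl.load(B + k[:, None] * stride_bk + idx_n[None, :] * stride_bn, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
    c_mask = (idx_m[:, None] < M) & (idx_n[None, :] < N)
    tl.store(C + idx_m[:, None] * stride_cm + idx_n[None, :] * stride_cn, acc, mask=c_mask)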

There are a couple of main use cases for prologue fusion:

Prologue fusion is generally much less profitable than epilogue fusion, because the fused op must be applied to an element of an input on every loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt a fusion if it does not increase the number of memory bytes read inside the triton template, beyond a small multiplicative factor that allows gathers. This rules out reliably unprofitable fusions such as an fp32->fp16 downcast inside the kernel (which would double the bytes read for that input). In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
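
A hedged sketch of that byte-count check (illustrative names and slack factor, not the actual scheduler code):

def can_fuse_prologue(prologue_read_bytes: int, template_input_read_bytes: int) -> bool:
    # allow a small amount of slack so gathers, which read an extra index
    # tensor, can still fuse
    GATHER_SLACK = 1.25  # assumption: illustrative value, not the PR's constant
    # e.g. fusing an fp32 -> fp16 downcast would double the bytes read for
    # that input, so it is rejected here
    return prologue_read_bytes <= template_input_read_bytes * GATHER_SLACK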

Other notes:

By default we upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for the prologue. I am currently using the codegen_upcast_to_fp32 option to avoid it, but that will not work for libdevice calls that require fp32. We will need #136778 and dtype-aware codegen to upcast fp16 ops into libdevice calls.
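
As a hedged illustration of the option mentioned above (the exact config path is an assumption and may differ across versions), disabling the blanket upcast looks roughly like:

import torch._inductor.config as inductor_config

# assumption: this is the flag referenced above; it keeps e.g. an fp16 relu
# prologue in fp16 instead of round-tripping through fp32
inductor_config.triton.codegen_upcast_to_fp32 = False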

With prologue fusion, we now essentially have separate kernels for each input and for the output. I had to increase the number of fields that get swapped out in set_subgraph_body by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs this could get cleaned up a bit.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec

pytorch-bot bot commented Aug 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134532

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 91ef793 with merge base 39cacc1:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Aug 27, 2024
ghstack-source-id: 8c7d4c0
Pull Request resolved: #134532
eellison added a commit that referenced this pull request Sep 12, 2024
ghstack-source-id: 49a84d9
Pull Request resolved: #134532
eellison added a commit that referenced this pull request Sep 13, 2024
ghstack-source-id: 12b7124
Pull Request resolved: #134532
eellison added a commit that referenced this pull request Sep 15, 2024
ghstack-source-id: ee0f4b2
Pull Request resolved: #134532
eellison added a commit that referenced this pull request Sep 17, 2024
ghstack-source-id: 67c0f48
Pull Request resolved: #134532
eellison added a commit that referenced this pull request Sep 23, 2024
ghstack-source-id: 8b318c0
Pull Request resolved: #134532
eellison marked this pull request as draft on September 24, 2024 at 00:19
eellison added a commit that referenced this pull request Sep 25, 2024
ghstack-source-id: 92734f0
Pull Request resolved: #134532
@eellison (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 970a2a8954f731ee818ac1a86dae8b012351f63f returned non-zero exit code 1

Auto-merging torch/_inductor/codecache.py
Auto-merging torch/_inductor/codegen/cpp.py
Auto-merging torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
Auto-merging torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py
Auto-merging torch/_inductor/codegen/simd.py
CONFLICT (content): Merge conflict in torch/_inductor/codegen/simd.py
Auto-merging torch/_inductor/codegen/triton.py
CONFLICT (content): Merge conflict in torch/_inductor/codegen/triton.py
Auto-merging torch/_inductor/ir.py
Auto-merging torch/_inductor/scheduler.py
Auto-merging torch/_inductor/select_algorithm.py
error: could not apply 970a2a8954f... Prologue Fusion
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config advice.mergeConflict false"

eellison added a commit that referenced this pull request Dec 12, 2024
ghstack-source-id: 6e03dc1
Pull Request resolved: #134532
@eellison (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Dec 16, 2024
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if (checked by the `use_triton_tma_template` helper; see the sketch after this list):

1. The min. hardware and Triton version requirements are met for the TMA support.

2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.
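
A hedged sketch of that gating (illustrative helper body, not the actual `use_triton_tma_template` implementation):

import torch
import torch._inductor.config as inductor_config

def use_triton_tma_template_sketch(a: torch.Tensor, b: torch.Tensor) -> bool:
    # 1. min. hardware and Triton version requirements (stand-in check;
    #    assumption: TMA needs a Hopper-class GPU)
    hw_ok = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    # 2. inputs compatible with the Triton TMA API: 16-byte aligned, contiguous
    aligned = all(t.data_ptr() % 16 == 0 and t.is_contiguous() for t in (a, b))
    # 3. the opt-in config flag named in the PR description
    return hw_ok and aligned and inductor_config.triton.enable_persistent_tma_matmul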

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prologue / epilogue fusion. To this end, in the new Triton template we currently support: TMA-based loads of A/B, but no prologue fusion; epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.

3. The transposed layouts of A and/or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on kernel perf, as it is decided at Triton compilation time); see the sketch after this list.

4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in triton-lang/triton#5290) in the new Triton template, which should allow lifting 2 and 3 above.

5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.

6. This PR is rebased onto and unifies with two related PRs landed previously: #142045 (some infra unification with the persistent+TMA template for _scaled_mm) and #134532 (add possibility to disable prolog fusion for selected choices).

7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding persistent+TMA template for `torch.bmm`.
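
A hedged sketch of the constexpr-flag approach from note 3 (illustrative names, masks omitted for brevity; not the kernel from #142101). The branch is resolved at Triton compile time, so the transposed path adds no runtime cost:

import triton
import triton.language as tl

@triton.jit
def load_b_tile(B, k, idx_n, stride_bk, stride_bn, B_TRANSPOSED: tl.constexpr):
    if B_TRANSPOSED:
        # B stored as (N, K): swap the index ordering, then transpose the tile
        tile = tl.load(B + idx_n[:, None] * stride_bn + k[None, :] * stride_bk)
        return tl.trans(tile)
    return tl.load(B + k[:, None] * stride_bk + idx_n[None, :] * stride_bn)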

Pull Request resolved: #142101
Approved by: https://github.com/drisspg, https://github.com/eellison
github-actions bot deleted the gh/eellison/690/head branch on January 14, 2025 at 02:03