Add persistent+TMA version of Triton mm and addmm #142101
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142101
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3bc5de0 with merge base e0bdae7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
# For prologue fusion we check if the underlying template of the choice
# supports all allowed prologue inputs. If not, we skip this choice in
# the fusion benchmark.
# TODO: Remove this check after all Triton templates support prologue fusion.
# Currently, persistent+TMA Triton template does not due to the TMA-based loads.
if (
    not epilogue_fusion
    and hasattr(choice, "allowed_prologue_inps")
    and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps
):
    continue
```
cc @eellison: this is to selectively skip choices not supporting prologue fusion (like currently the choices from the persistent+TMA template).
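For readers skimming the thread, a self-contained illustration of how this filter behaves; the `SimpleNamespace` stand-ins below are hypothetical placeholders, not the real inductor choice / multi-node classes.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for a fused multi-node and two autotuning choices.
multi_node = SimpleNamespace(allowed_prologue_inps=frozenset({"arg_A", "arg_B"}))
choices = [
    SimpleNamespace(name="triton_mm", allowed_prologue_inps=frozenset({"arg_A", "arg_B"})),
    SimpleNamespace(name="triton_mm_persistent_tma", allowed_prologue_inps=frozenset()),
]

epilogue_fusion = False  # i.e., we are benchmarking prologue fusion
kept = []
for choice in choices:
    if (
        not epilogue_fusion
        and hasattr(choice, "allowed_prologue_inps")
        and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps
    ):
        continue  # the persistent+TMA choice is skipped here
    kept.append(choice.name)

print(kept)  # ['triton_mm']
```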
```python
scaled_persistent_mm_kernel_configs = [
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 128, 3, 8), "cond": True},
    {"config": (128, 128, 128, 4, 8), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 3, 4), "cond": True},
    {"config": (128, 128, 128, 5, 4), "cond": True},
    {"config": (128, 128, 128, 5, 8), "cond": True},
    {"config": (128, 128, 128, 6, 8), "cond": True},
    {"config": (128, 128, 64, 4, 8), "cond": True},
]
```
cc @drisspg: separated these configs used in the _scaled_mm persistent+TMA template-based lowering, as ~half of them OOM on SMEM for 2-byte dtypes. Kept the ones that don't in persistent_mm_kernel_configs.
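For context, here is my own back-of-envelope arithmetic (not the model Inductor or Triton actually uses) for why the higher-staged 128x128x128 configs exceed shared memory with 2-byte dtypes:

```python
# Rough SMEM estimate for a (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
# config: the pipeliner keeps num_stages copies of the A and B tiles resident.
# This ignores epilogue buffers and padding, so treat it as a ballpark only.
def approx_smem_kib(block_m, block_n, block_k, num_stages, itemsize):
    a_tile = block_m * block_k * itemsize
    b_tile = block_k * block_n * itemsize
    return (a_tile + b_tile) * num_stages // 1024

for m, n, k, stages in [(128, 128, 64, 3), (128, 128, 128, 3), (128, 128, 128, 5), (128, 128, 128, 6)]:
    # fp8 (itemsize=1) stays within the ~228 KiB of SMEM per SM on H100,
    # while bf16/fp16 (itemsize=2) at 5-6 stages needs ~320-384 KiB and would not.
    print((m, n, k, stages), approx_smem_kib(m, n, k, stages, itemsize=1), approx_smem_kib(m, n, k, stages, itemsize=2))
```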
```python
    w_inverse_scale,
    bias,
)
with config.patch({"triton.enable_persistent_tma_matmul": True}):
```
wow, bad mistake on my end, thanks for fixing!
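For anyone who wants to try the new path outside the test suite, a minimal usage sketch (my own example; it assumes a Hopper-class GPU, a recent enough Triton, and that max-autotune is used so the Triton GEMM templates are actually benchmarked):

```python
import torch
import torch._inductor.config as inductor_config

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)

# Enable the new flag and compile with autotuning so the persistent+TMA
# template participates in the choice benchmark alongside the other mm choices.
with inductor_config.patch({"triton.enable_persistent_tma_matmul": True}):
    compiled_mm = torch.compile(torch.mm, mode="max-autotune")
    out = compiled_mm(a, b)
```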
torch/_inductor/kernel/mm.py (Outdated)
```python
# based on triton.ops.matmul
start_pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
```
nit: why not cdiv for these too?
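For reference, the two spellings compute the same thing; a quick plain-Python check of the ceiling-division identity (inside the kernel it would be `tl.cdiv(M, BLOCK_M)`):

```python
# (M + BLOCK_M - 1) // BLOCK_M is ceiling division for positive block sizes,
# which is exactly what tl.cdiv provides.
def cdiv(x: int, y: int) -> int:
    return (x + y - 1) // y

assert cdiv(1024, 128) == 8   # exact multiple
assert cdiv(1000, 128) == 8   # 7.8125 rounds up to 8
assert cdiv(1024, 127) == 9
```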
torch/_inductor/kernel/mm_scaled.py (Outdated)
```diff
-workspace_arg=get_workspace_arg(
-    kwargs["NUM_SMS"], mat_a.get_device()
+workspace_arg=get_tma_workspace_arg(
+    num_tma_descriptors=3,
```
we actually only need 2, if you care to update as well; my top of stack had the C stores, but that was buggy anyways
Thanks for catching this! Helped me find a nasty bug in the new template code, too.
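For readers following along, a rough illustration of how such a TMA descriptor workspace is sized; the helper below and the per-SM layout are my own sketch, not the actual get_tma_workspace_arg implementation, though the 128-byte figure matches the size of a CUDA CUtensorMap.

```python
# Each persistent program gets its own on-device copy of each TMA descriptor,
# so the workspace scales with num_descriptors * num_programs. With TMA used
# only for the A and B loads (no TMA stores of C), two descriptors suffice,
# which is what the review comment above points out.
TMA_DESCRIPTOR_SIZE = 128  # bytes per CUtensorMap

def tma_workspace_bytes(num_tma_descriptors: int, num_programs: int) -> int:
    return num_tma_descriptors * num_programs * TMA_DESCRIPTOR_SIZE

num_sms = 132  # e.g., H100 SXM
print(tma_workspace_bytes(3, num_sms))  # 50688 bytes
print(tma_workspace_bytes(2, num_sms))  # 33792 bytes with A/B loads only
```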
Looks great! Will let Elias comment on the prologue stuff.
eellison left a comment:
looks great!
```python
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
```
maybe in a follow-up we can dedup with the scaled version; cc @drisspg
Yeah should/want to do this
```python
    [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    A.dtype.element_ty,
)
b = tl._experimental_descriptor_load(
```
I don't see any mask here. I guess it doesn't support K not divisible by the K block?
TMA does support K not divisible by the K block (given that K * dtype.itemsize % 16 == 0). Masking happens in the HW doing the TMA, with the OOB values set to zero.
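A tiny arithmetic illustration of that constraint (my own sketch of the alignment rule stated above, not code from the PR):

```python
# TMA needs the innermost (contiguous) dimension to be 16-byte aligned; it does
# not require K to be a multiple of BLOCK_K, since out-of-bounds tile elements
# are zero-filled by the hardware.
def tma_inner_dim_ok(k: int, itemsize: int) -> bool:
    return (k * itemsize) % 16 == 0

print(tma_inner_dim_ok(k=192, itemsize=2))  # True: 384 bytes (bf16/fp16), even if K % BLOCK_K != 0
print(tma_inner_dim_ok(k=100, itemsize=2))  # False: 200 bytes is not 16-byte aligned
print(tma_inner_dim_ok(k=100, itemsize=4))  # True: 400 bytes (fp32)
```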
```python
    c
    for c in choices
    if re.search(
        config.test_configs.autotune_choice_name_regex,
```
nit: is it worth functools.lru_cache-ing re.compile() and using that?
I imagine these regexes being used exclusively for testing (as in: not in prod / with real models) and being relatively simple (probably a substring, or a few separated by |). So I'm not sure how much it's worth precompiling them. Happy to add it if you feel it's worth it, though.
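For completeness, the cached variant the nit suggests would look roughly like the sketch below; note that re already memoizes compiled patterns internally, which is part of why the gain is marginal for simple, test-only regexes.

```python
import functools
import re

@functools.lru_cache(maxsize=None)
def _compiled(pattern: str) -> re.Pattern:
    # Compile once per distinct pattern string and reuse across filtering calls.
    return re.compile(pattern)

choices = ["mm_persistent_tma", "mm", "extern_mm"]
pattern = "persistent_tma|extern"
filtered = [c for c in choices if _compiled(pattern).search(c)]
print(filtered)  # ['mm_persistent_tma', 'extern_mm']
```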
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command ... Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```python
)
if ki == k_tiles - 1:
    # rematerialize rm and rn to save registers
```
how do you determine if not doing this results in spill/fills?
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the tuned_mm and tuned_addmm lowerings. The persistent+TMA choices are added to the GEMM autotuning if (as checked by the use_triton_tma_template helper; a simplified illustrative sketch of this gating is included at the end of this description):

1. The minimum hardware and Triton version requirements for TMA support are met.
2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).
3. config.triton.enable_persistent_tma_matmul is set to True.

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prologue / epilogue fusion. To this end, the new Triton template currently supports TMA-based loads of A/B but no prologue fusion, and epilogue fusion but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.
2. The current Triton TMA API (experimental_device_tensormap_create2d) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.
3. The transposed layouts of A and/or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on kernel perf, as it is decided at Triton compilation time).
4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in [Pipeliner] Multi-buffer TMA descriptors, triton-lang/triton#5290) in the new Triton template, which should allow lifting 2 and 3 above.
5. The configs for the new Triton template in persistent_mm_kernel_configs are preliminary. We should do more perf exploration and possibly augment the configs in a follow-up.
6. This PR is rebased onto and unifies with two related PRs landed previously: "Adding lowering to persistent-tma device kernel for _scaled_mm" #142045 (some infra unification with the persistent+TMA template for _scaled_mm) and "Prologue Fusion" #134532 (adds the possibility to disable prologue fusion for selected choices).
7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see here). For now, this blocks adding a persistent+TMA template for torch.bmm.

Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
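As referenced in the description above, here is a simplified sketch of the gating the three conditions describe. The helper names below (has_tma_hardware_support, tma_compatible, can_use_persistent_tma) are illustrative, not the actual use_triton_tma_template implementation, and the sketch ignores details such as the transposed-layout handling from the notes.

```python
import torch

def has_tma_hardware_support() -> bool:
    # TMA requires a Hopper-class GPU (sm90+) and a recent enough Triton.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9

def tma_compatible(t: torch.Tensor) -> bool:
    # The current Triton TMA API takes no strides, so the input must be
    # contiguous, with a 16-byte aligned base pointer and innermost dimension.
    return (
        t.dim() == 2
        and t.is_contiguous()
        and t.data_ptr() % 16 == 0
        and (t.size(-1) * t.element_size()) % 16 == 0
    )

def can_use_persistent_tma(a: torch.Tensor, b: torch.Tensor, enabled_in_config: bool) -> bool:
    # enabled_in_config mirrors config.triton.enable_persistent_tma_matmul.
    return enabled_in_config and has_tma_hardware_support() and tma_compatible(a) and tma_compatible(b)
```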